From 878dba90609d7687d549985d88f0bc0059abd8f7 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 6 Feb 2026 21:02:12 +0000 Subject: [PATCH 01/52] add rabbit feedback Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 1982fee716..aaa210d4fb 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -564,10 +564,17 @@ def forward(self, input, *args, **kwargs): for name, module in name_to_module.items(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model, name_to_module): +<<<<<<< HEAD module.hessian_helper = LocalHessianHelper(module, name) module.hessian_helper.setup() all_patched_modules.append((name, module)) if module.hessian_helper.is_enabled: +======= + module.local_hessian = LocalHessianHelper(module, name) + module.local_hessian.setup() + all_patched_modules.append((name, module)) + if module.local_hessian.is_enabled: +>>>>>>> e391ea1a (add rabbit feedback) weight_quantizers_info.append((name, module)) # Cache activations by running forward loop @@ -690,7 +697,11 @@ def quant_func(x, amax, quantizer=weight_quantizer): # Cleanup and free memory LocalHessianHelper.cache_mode = False for name, module in all_patched_modules: +<<<<<<< HEAD module.hessian_helper.cleanup() +======= + module.local_hessian.cleanup() +>>>>>>> e391ea1a (add rabbit feedback) print_rank_0("local_hessian: Calibration complete.") From 9545e2f01d22ab90cb6c273480756f73756c0430 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:12:30 -0800 Subject: [PATCH 02/52] minor 
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index aaa210d4fb..1982fee716 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -564,17 +564,10 @@ def forward(self, input, *args, **kwargs): for name, module in name_to_module.items(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model, name_to_module): -<<<<<<< HEAD module.hessian_helper = LocalHessianHelper(module, name) module.hessian_helper.setup() all_patched_modules.append((name, module)) if module.hessian_helper.is_enabled: -======= - module.local_hessian = LocalHessianHelper(module, name) - module.local_hessian.setup() - all_patched_modules.append((name, module)) - if module.local_hessian.is_enabled: ->>>>>>> e391ea1a (add rabbit feedback) weight_quantizers_info.append((name, module)) # Cache activations by running forward loop @@ -697,11 +690,7 @@ def quant_func(x, amax, quantizer=weight_quantizer): # Cleanup and free memory LocalHessianHelper.cache_mode = False for name, module in all_patched_modules: -<<<<<<< HEAD module.hessian_helper.cleanup() -======= - module.local_hessian.cleanup() ->>>>>>> e391ea1a (add rabbit feedback) print_rank_0("local_hessian: Calibration complete.") From 7d56d641b47b92e601b9a7c6875611b39671e628 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 03/52] tested perplexity Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 56 +++++++++++++++++++++++++++ modelopt/torch/quantization/mode.py | 14 +++++++ 
modelopt/torch/utils/network.py | 1 + 3 files changed, 71 insertions(+) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 7bb1e2322d..a48656cdc4 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -374,6 +374,18 @@ def find_quant_cfg_entry_by_path( "algorithm": "max", } +INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, + "*input_quantizer": {"enable": False}, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + INT4_AWQ_CFG = { "quant_cfg": [ @@ -1372,6 +1384,43 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig): description="If True, module's local Hessian metadata will be kept as a module attribute.", ) +class GPTQConfig(QuantizeAlgorithmConfig): + """The config for GPTQ lite. + + GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. + + GPTQ lite does not perform sequential quantization of layers. This means that the updated + activations are not used to process the next layer. + + The default values are taken from the official GPTQ implementation: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 + + Note: This feature is currently experimental and may not translate to improved accuracy as expected. 
+ + + """ + + method: Literal["gptq"] = ModeloptField("gptq") + percdamp: float | None = ModeloptField( + default=0.01, + gt=0.0, + le=1.0, + title="Percentage damping factor.", + description="The percentage of average Hessian diagonal used for damping.", + ) + block_size: int | None = ModeloptField( + default=128, + title="Block size for GPTQ weight update.", + description="""The block size for GPTQ weight update, which must be a multiple of the + group_size used in the quantization.""", + ) + hessian_state_path: str | None = ModeloptField( + default=None, + title="Path to the Hessian state file.", + description="""The path to the Hessian state file. If hessian path exists, we load from + hessian file instead of recomputing them.""", + ) + class SmoothQuantCalibConfig(QuantizeAlgorithmConfig): """The config for ``smoothquant`` algorithm (SmoothQuant). @@ -1543,6 +1592,13 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig): QuantizeQuantCfgType = list[QuantizerCfgEntry] +QuantizeQuantCfgType = dict[ + str | Callable, + QuantizerAttributeConfig + | list[QuantizerAttributeConfig] + | dict[str | Callable, QuantizerAttributeConfig | list[QuantizerAttributeConfig]], +] + _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index e08efece9a..88e93bb770 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -37,6 +37,7 @@ AWQFullCalibConfig, AWQLiteCalibConfig, CompressConfig, + GPTQConfig, GPTQLiteConfig, LocalHessianCalibConfig, MaxCalibConfig, @@ -59,6 +60,7 @@ ) from .model_calib import ( awq, + gptq, gptq_lite, local_hessian_calibrate, max_calibrate, @@ -502,3 +504,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: return GPTQLiteConfig _calib_func = gptq_lite + + +@CalibrateModeRegistry.register_mode +class 
GPTQModeDescriptor(BaseCalibrateModeDescriptor): + """Mode for GPTQ calibration algorithm.""" + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return GPTQConfig + + _calib_func = gptq diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b54332375b..b07ca570c4 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -46,6 +46,7 @@ def _convert_to_wrapped_module_name(name: str) -> str: "ModelLike", "compare_dict", "create_param_grad_clear_hook", + "get_decoder_layers", "get_model_attributes", "get_module_device", "get_same_padding", From da41e3f72e183634a1b78d2535702123d57818f4 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:46:47 +0000 Subject: [PATCH 04/52] tested, revert later Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index bb240ba0cb..ec7281998c 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -693,6 +693,82 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." 
) + if True: + # Disable quantizers + # mtq.fold_weight(full_model) + # print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") + mtq.disable_quantizer(full_model, "*") + if True: + # mtq.fold_weight(full_model) + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), ".hf_cache" + ) + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed to far + tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, 
shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + + breakpoint() # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization From f98154ce3c08f9fd33579df5e6646d1c32acbc90 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 10 Feb 2026 04:41:46 +0000 Subject: [PATCH 05/52] tested Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 17 +++-- modelopt/torch/quantization/config.py | 94 +++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index ec7281998c..4ec18c2745 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -693,14 +693,16 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." 
) - if True: + if args.export_qdq_weights: # Disable quantizers - # mtq.fold_weight(full_model) - # print("Folded weights") + if "gptq" not in args.qformat: + mtq.fold_weight(full_model) + print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") mtq.disable_quantizer(full_model, "*") + if True: - # mtq.fold_weight(full_model) import os import torch.nn.functional as F @@ -768,7 +770,6 @@ def _compute_perplexity(model, data, batch_size: int = 1): ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - breakpoint() # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization @@ -1246,6 +1247,12 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) + parser.add_argument( + "--export_qdq_weights", + help=("Used for GPTQ weights as is without compressed weights for deployment."), + default=False, + action="store_true", + ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). 
Disable by --no-verbose.", diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index a48656cdc4..d41c210381 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -386,6 +386,100 @@ def find_quant_cfg_entry_by_path( }, } +NVFP4_STATIC_WO_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "max", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_DYNAMIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + 
"enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} INT4_AWQ_CFG = { "quant_cfg": [ From 1120b74b57ef6008ef73538a663b7506e4f29e20 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 11 Feb 2026 07:43:06 +0000 Subject: [PATCH 06/52] refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 80 +++++----------------- 1 file changed, 16 insertions(+), 64 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 1982fee716..3f4710905f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1632,56 +1632,6 @@ def prepare_hessian_inverse(h, weight, percdamp): return h_inv -def quantize_block(full_weight, block_start, block_end, h_inv, quantizer): - """Quantize a block of weights group by group (based on quantizer block sizes) with error propagation. 
- - Args: - full_weight: The full weight tensor (needed for INT4 quantization) - block_start: Starting column index of the block - block_end: Ending column index of the block - h_inv: Hessian inverse - quantizer: The quantizer to apply - Returns: - quantized_block: Quantized weights for this block - losses: Quantization losses per element - errors: Accumulated errors for propagation - """ - # Extract the block we're working on - block_weight = full_weight[:, block_start:block_end] - block_hinv = h_inv[block_start:block_end, block_start:block_end] - block_size = block_end - block_start - - quantized_block = torch.zeros_like(block_weight) - losses = torch.zeros_like(block_weight) - errors = torch.zeros_like(block_weight) - - # We perform column-wise update for GPTQ within the block - group_size = 1 - - for group_start in range(0, block_size, group_size): - group_end = min(group_start + group_size, block_size) - group_cols = slice(group_start, group_end) - # Get current column and its Hessian inverse diagonal - weight_col = block_weight[:, group_cols] - hinv_diag = torch.diag(block_hinv[group_cols, group_cols]) - - # Quantize using the full weight, then extract the columns we need - quantized_full = quantizer(full_weight) - quantized_cols = quantized_full[:, block_start + group_start : block_start + group_end] - quantized_block[:, group_cols] = quantized_cols - - # Compute quantization error and loss - error = (weight_col - quantized_cols) / hinv_diag - losses[:, group_cols] = (weight_col - quantized_cols) ** 2 / (hinv_diag**2) / 2 - errors[:, group_cols] = error - - # Propagate error to remaining columns in block - block_weight[:, group_start:] -= error @ block_hinv[group_start:group_end, group_start:] - full_weight[:, block_start:block_end] = block_weight - - return quantized_block, losses, errors - - def blockwise_weight_update(module, h, block_size, percdamp): """Update module weights using GPTQ-style blockwise quantization. 
@@ -1697,28 +1647,30 @@ def blockwise_weight_update(module, h, block_size, percdamp): # Preprocess Hessian: handle dead neurons and add damping h_inv = prepare_hessian_inverse(h, weight, percdamp) - # Initialize output tensors - quantized_weight = torch.zeros_like(weight) - losses = torch.zeros_like(weight) - # Process weights in blocks for block_start in range(0, num_cols, block_size): block_end = min(block_start + block_size, num_cols) - - quantized_block, block_losses, block_errors = quantize_block( - weight, block_start, block_end, h_inv, module.weight_quantizer - ) - # Store results - quantized_weight[:, block_start:block_end] = quantized_block - losses[:, block_start:block_end] = block_losses + n_cols = block_end - block_start + wblk = weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] + + for i in range(n_cols): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = module.weight_quantizer(wblk) + weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err # Propagate errors to remaining weights - weight[:, block_end:] -= block_errors @ h_inv[block_start:block_end, block_end:] + weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) # Print relative mse error - _print_relative_mse_error(quantized_weight, module.weight.float(), h, module.name) + _print_relative_mse_error(weight, module.weight.float(), h, module.name) # Update module weights - module.weight.data = quantized_weight.reshape(module.weight.shape).to(module.weight.data.dtype) + module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) def gptq_lite( From a16fcaeb46ae0898395da224ecca964fd94a28b9 Mon Sep 17 00:00:00 2001 From: realAsma <86726418+realAsma@users.noreply.github.com> Date: Fri, 6 Feb 2026 
11:47:36 -0800 Subject: [PATCH 07/52] Track global_amax for weight FP4 MSE sweep; Refactor to NVFP4StaticQantizer, NVFP4MSECalibrator (#849) **Type of change:** ? **Overview:** ? ```python ``` - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No * **New Features** * Added NVFP4StaticQuantizer for improved 4-bit quantization with enhanced precision control * Introduced NVFP4MSECalibrator with flexible candidate generation for calibration optimization * **Improvements** * Optimized GPU kernels for Hopper+ graphics cards with better performance * Extended Triton support to broader GPU compatibility * Enhanced backward compatibility for restoring previously quantized models * **Tests** * Added comprehensive test coverage for new quantizers and calibration methods --------- Signed-off-by: realAsma Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/triton/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py index def70e5914..6e8d4dba11 100644 --- a/modelopt/torch/quantization/triton/__init__.py +++ b/modelopt/torch/quantization/triton/__init__.py @@ -34,6 +34,10 @@ from .fp4_kernel import * from .fp8_kernel import * + # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) + if torch.cuda.get_device_capability() >= (8, 9): + from .fp4_kernel_hopper import * + # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): from .fp4_kernel_hopper import * From 
2771a9db7590265362009c3bb01484a4f30f4a09 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 6 Feb 2026 22:56:02 +0000 Subject: [PATCH 08/52] address reviewers feedback, delegate scaling factor calculation to NVFP4QTensor Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 4ceb51cd2c..b762757cb9 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -360,9 +360,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_NVFP4_FP8, ]: - # Calibrate weight quantizer if amax is not set - module_name = f"{type(module).__name__}.{weight_name}" - _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) + # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers) + if not is_nvfp4_static: + module_name = f"{type(module).__name__}.{weight_name}" + _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. 
From e73dc52b067815ddeca89a3bbab07325c88d6fb9 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 09/52] tested perplexity Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 87 ++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3f4710905f..3b3ace25d1 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1866,3 +1866,90 @@ def _layer_forward_loop(m, _inputs=layer_inputs): torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() + + print_rank_0("Sequential calibration completed successfully") + + +@torch.no_grad() +def gptq( + layer: nn.Module, + inputs: list[tuple[tuple, dict]], + percdamp: float = 0.01, + block_size: int = 128, + **kwargs, +): + """GPTQ quantization - a GPTQ variant.""" + import time + + total_start = time.time() + + # Dictionary to store hessian matrices for all linear layers in this decoder + hessian_state = {} + + # Phase 1: Build tensor mapping for all quantized linear layers in this decoder layer + tensor_mapping = {} + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + in_features = module.weight.shape[-1] + tensor_mapping[name] = ((in_features, in_features), module.weight.device) + module.name = name # Attach name for easy access in hooks + + if not tensor_mapping: + print_rank_0("No quantized linear layers found in decoder layer, skipping GPTQ") + return + + # Initialize hessian state with zeros + for name, (shape, device) in tensor_mapping.items(): + hessian_state[name] = { + "hessian": torch.zeros(shape, dtype=torch.float32, device=device), + "n_samples": 0, + } + + # Phase 2: Register hooks to collect Hessians during forward passes 
+ def hessian_hook(module, input, output): + """Hook to intercept activations and update hessian matrix.""" + state = hessian_state[module.name] + hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) + hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + + handles = [] + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + handles.append(module.register_forward_hook(hessian_hook)) + + # Run forward passes with the provided inputs to collect Hessians + hessian_start = time.time() + print_rank_0( + f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." + ) + for args, kwargs_input in inputs: + layer(*args, **kwargs_input) + + # Remove hooks after collecting Hessians + for handle in handles: + handle.remove() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + hessian_time = time.time() - hessian_start + + # Phase 3: Update weights using computed Hessians (same as gptq_lite) + weight_update_start = time.time() + print_rank_0("Updating weights using GPTQ algorithm...") + for name, module in layer.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + state = hessian_state[module.name] + hessian = state["hessian"].to(module.weight.device) + blockwise_weight_update(module, hessian, block_size, percdamp) + # Free memory + del hessian_state[module.name] + torch.cuda.empty_cache() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + weight_update_time = time.time() - weight_update_start + + total_time = time.time() - total_start + print_rank_0( + f"GPTQ timing - Hessian: {hessian_time:.2f}s, " + f"Weight update: {weight_update_time:.2f}s, " + f"Total: {total_time:.2f}s" + ) From da757d49fc9e1f9a6a646abb50f9b5b42527c143 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 12 Feb 2026 18:54:46 +0000 
Subject: [PATCH 10/52] tested exported checkpoints on 0211 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 69 ++++++++++++++++++++++ modelopt/torch/export/unified_export_hf.py | 4 +- modelopt/torch/quantization/config.py | 22 +++++++ modelopt/torch/quantization/model_calib.py | 57 +++++++++++++++++- 4 files changed, 147 insertions(+), 5 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 4ec18c2745..3da8856f17 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -693,6 +693,75 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) + + if True: + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".hf_cache") + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed to far + 
tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + print(f"Saving model to {args.export_path}") + full_model.save_pretrained(args.export_path) + if args.export_qdq_weights: # Disable quantizers if "gptq" not in args.qformat: diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 4871d36b08..8a542b580d 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -559,7 +559,7 @@ def _export_quantized_weight( )[0] quantized_weight = to_quantized_weight( - weight.to(dtype), + weight.to(torch.bfloat16), weight_scale, quantization_format, weight_scale_2, @@ -576,7 +576,7 @@ def _export_quantized_weight( ) quantized_weight = to_quantized_weight( - weight.to(dtype), + weight.to(torch.bfloat16), weight_scale, quantization_format, weight_scale_2, diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index d41c210381..837bd29c2c 
100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -405,6 +405,28 @@ def find_quant_cfg_entry_by_path( }, } +NVFP4_STATIC_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + NVFP4_STATIC_WO_GPTQ_LITE_CFG = { "quant_cfg": { "*weight_quantizer": { diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3b3ace25d1..546e8ebf4f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -15,6 +15,7 @@ """Calibration utilities.""" +import contextlib import math import os import warnings @@ -1820,6 +1821,56 @@ def hessian_hook(module, input, output): print_rank_0("GPTQ-lite quantization completed successfully") +def _set_input_quantizers_calib_mode(layer: nn.Module): + """Set all input quantizers of a layer to calibration mode.""" + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and "input_quantizer" in name + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + module._calibrator.reset() + module.disable_quant() + module.enable_calib() + + +def _set_input_quantizers_quant_mode(layer: nn.Module): + """Load fresh amaxes and restore all input quantizers of a layer to quant mode.""" + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and "input_quantizer" in name + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + if module._calibrator.compute_amax() is not None: + 
module.load_calib_amax() + module.enable_quant() + module.disable_calib() + + +@contextlib.contextmanager +def _disable_input_quantizers(layer: nn.Module): + """Temporarily disable all enabled input quantizers in a layer.""" + enabled_quantizers = [] + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and "input_quantizer" in name + and not module._disabled + ): + module.disable() + enabled_quantizers.append(module) + try: + yield + finally: + for module in enabled_quantizers: + module.enable() + + @torch.no_grad() def sequential_calibrate( model: nn.Module, @@ -1867,8 +1918,6 @@ def _layer_forward_loop(m, _inputs=layer_inputs): finally: input_getter._unpatch_all_layers() - print_rank_0("Sequential calibration completed successfully") - @torch.no_grad() def gptq( @@ -1908,8 +1957,10 @@ def gptq( # Phase 2: Register hooks to collect Hessians during forward passes def hessian_hook(module, input, output): """Hook to intercept activations and update hessian matrix.""" + if hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled: + inp = module.input_quantizer(input[0]) state = hessian_state[module.name] - hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) + hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} handles = [] From 2fdcd22f5ac27d0af6cf70bca00865f15c15426d Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 13 Feb 2026 19:53:25 +0000 Subject: [PATCH 11/52] tested nano v3 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 3da8856f17..449b7b9dc6 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -759,8 +759,7 @@ 
def _compute_perplexity(model, data, batch_size: int = 1): eval_data = _get_wikitext2(tokenizer, 2048) ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - print(f"Saving model to {args.export_path}") - full_model.save_pretrained(args.export_path) + breakpoint() if args.export_qdq_weights: # Disable quantizers @@ -768,8 +767,8 @@ def _compute_perplexity(model, data, batch_size: int = 1): mtq.fold_weight(full_model) print("Folded weights") - print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") - mtq.disable_quantizer(full_model, "*") + print(f"Saving model to {args.export_path}") + full_model.save_pretrained(args.export_path) if True: import os From 49aa3d33c8f9b72d4195e15da27a632c20a22cc9 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 16 Feb 2026 02:48:11 +0000 Subject: [PATCH 12/52] added activation MSE logging Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 48 ++++++++++++++++++++++ modelopt/torch/quantization/__init__.py | 1 + modelopt/torch/quantization/model_calib.py | 2 + 3 files changed, 51 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 449b7b9dc6..53c1f2de19 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -15,6 +15,7 @@ import argparse import copy +import os import random import time import warnings @@ -583,6 +584,43 @@ def mono_quantize( else: calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + # Phase 1: Collect pre-quantization activations (batch_size=1 to save memory) + if getattr(args, "measure_activation_mse", False): + mse_max_samples = getattr(args, "activation_mse_max_samples", 16) + mse_save_dir = getattr(args, "activation_mse_save_dir", None) + mse_input_path = getattr(args, "activation_mse_input_path", None) + + # Materialize or load a frozen set of MSE inputs so 
that the exact + # same samples are used across runs and across codebases. + if mse_input_path and os.path.isfile(mse_input_path): + mse_data = mtq.ActivationMSELogger.load_data(mse_input_path) + else: + from torch.utils.data import DataLoader as _DataLoader + + mse_dataloader = _DataLoader(calib_dataloader.dataset, batch_size=1, shuffle=False) + if mse_input_path: + mse_data = mtq.ActivationMSELogger.materialize_data( + mse_dataloader, + mse_input_path, + max_samples=mse_max_samples, + ) + else: + # No path given -- materialize in memory only + mse_data = [] + for i, batch in enumerate(mse_dataloader): + if i >= mse_max_samples: + break + t = batch["input_ids"] if isinstance(batch, dict) else batch + mse_data.append(t.cpu()) + + mse_logger = mtq.ActivationMSELogger( + max_samples=mse_max_samples, + layer_filter=getattr(args, "activation_mse_layer_filter", None), + save_dir=mse_save_dir, + ) + print("\n--- Phase 1: Collecting pre-quantization activations ---") + mse_logger.collect(language_model, mse_data, phase="original") + if calibration_only: language_model = mtq.calibrate( language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop @@ -590,6 +628,16 @@ def mono_quantize( else: language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop) + # Phase 2: Compute MSE against stored pre-quant activations + if getattr(args, "measure_activation_mse", False): + print("\n--- Phase 2: Computing per-layer activation MSE ---") + mse_logger.collect(language_model, mse_data, phase="quantized") + mse_logger.compute_mse() + print(mse_logger.summary()) + if mse_save_dir: + mse_logger.save() + del mse_logger, mse_data + # For VL models, update full_model to use the quantized language model if is_nemotron_vl_model: language_model_lineage = get_language_model_from_vl(full_model) diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index 87dbf30bb5..757b844fb1 100644 --- 
a/modelopt/torch/quantization/__init__.py +++ b/modelopt/torch/quantization/__init__.py @@ -19,6 +19,7 @@ from . import mode, plugins, utils # Add methods to mtq namespace +from .activation_mse import ActivationMSELogger, collect_activations, measure_activation_mse from .compress import * from .config import * from .conversion import * diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 546e8ebf4f..9b95dd0b16 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1959,6 +1959,8 @@ def hessian_hook(module, input, output): """Hook to intercept activations and update hessian matrix.""" if hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled: inp = module.input_quantizer(input[0]) + else: + inp = input[0] state = hessian_state[module.name] hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} From 1cf4b1a22c94b3da124acf45a25d914016bd5a68 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 17 Feb 2026 06:07:59 +0000 Subject: [PATCH 13/52] super v3 run Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 69 +++++++++ modelopt/torch/quantization/model_calib.py | 137 +++++++++++++++++- .../nn/modules/tensor_quantizer.py | 13 +- 3 files changed, 209 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 837bd29c2c..a174fb3291 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -242,6 +242,54 @@ def find_quant_cfg_entry_by_path( {"quantizer_name": "*o_proj*", "enable": False}, # Skip QKV Output Projection ] +SUPER_NVFP4_CONSERVATIVE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + 
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear + }, + "algorithm": "max", +} + +SUPER_NVFP4_CONSERVATIVE_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + + INT8_DEFAULT_CFG = { "quant_cfg": [ *_base_disable_all, @@ -398,6 +446,9 @@ def find_quant_cfg_entry_by_path( "enable": False, }, **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "gptq", @@ -420,6 +471,9 @@ def find_quant_cfg_entry_by_path( "enable": True, }, **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "gptq", @@ -1355,6 +1409,21 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) + checkpoint_every_n_layers: int | None = ModeloptField( + default=None, + title="Save intermediate checkpoint every N layers 
during sequential calibration.", + ) + + checkpoint_dir: str | None = ModeloptField( + default=None, + title="Directory for saving/loading intermediate GPTQ checkpoints.", + ) + + resume_from_layer: int = ModeloptField( + default=0, + title="Layer index to resume sequential calibration from (0 = start from beginning).", + ) + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9b95dd0b16..18a98e0c31 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,6 +16,8 @@ """Calibration utilities.""" import contextlib +import datetime +import json import math import os import warnings @@ -1551,7 +1553,13 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -def _print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, module_name: str): +def _print_relative_mse_error( + q: torch.Tensor, + w: torch.Tensor, + h: torch.Tensor, + module_name: str, + n_samples: int | None = None, +): """Print relative mean squared error between quantized and original weights. 
Computes the Hessian-weighted relative MSE between quantized and original weights, @@ -1563,13 +1571,15 @@ def _print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, w (torch.Tensor): Original weight tensor h (torch.Tensor): Hessian matrix used for weighting the error module_name (str): Name of the module for logging purposes + n_samples (int | None): Number of Hessian samples (batches) used for this layer Note: Implementation adapted from the GPTQ repository: https://github.com/IST-DASLab/FP-Quant """ delta = q - w mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) - print(f"[{module_name}] Relative MSE error: {mse.item():.2e}") + suffix = f", n_hessian_samples: {n_samples}" if n_samples is not None else "" + print(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") def update_hessian(input, hessian, n_samples): @@ -1582,15 +1592,15 @@ def update_hessian(input, hessian, n_samples): Returns: Tuple of (updated_hessian, new_sample_count) """ - batch_size = input.shape[0] + # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens + input_flat = input.reshape(-1, input.shape[-1]).t().float() + batch_size = input_flat.shape[1] # Incremental averaging: scale down old hessian hessian *= n_samples / (n_samples + batch_size) n_samples += batch_size # Compute outer product: H += (2/n_samples) * X @ X^T - # where X is the flattened input reshaped to (features, batch*seq) - input_flat = input.reshape(-1, input.shape[-1]).t().float() scaled_input = math.sqrt(2 / n_samples) * input_flat hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) @@ -1633,7 +1643,7 @@ def prepare_hessian_inverse(h, weight, percdamp): return h_inv -def blockwise_weight_update(module, h, block_size, percdamp): +def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): """Update module weights using GPTQ-style blockwise quantization. 
Args: @@ -1641,6 +1651,7 @@ def blockwise_weight_update(module, h, block_size, percdamp): H: Hessian matrix (d x d) block_size: Size of blocks to process at once percdamp: Damping percentage for Hessian diagonal + n_samples: Number of Hessian samples for logging (optional) """ weight = module.weight.data.float().clone() _, num_cols = weight.shape @@ -1669,7 +1680,7 @@ def blockwise_weight_update(module, h, block_size, percdamp): weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) # Print relative mse error - _print_relative_mse_error(weight, module.weight.float(), h, module.name) + _print_relative_mse_error(weight, module.weight.float(), h, module.name, n_samples) # Update module weights module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) @@ -1871,11 +1882,117 @@ def _disable_input_quantizers(layer: nn.Module): module.enable() +def save_fake_checkpoint(model: nn.Module, output_dir: str) -> None: + """Save fake quant checkpoint using save_pretrained() (HuggingFace format). + + Args: + model: The quantized model to save. + output_dir: Directory to write the checkpoint into. + """ + from modelopt.torch.opt.conversion import ModeloptStateManager, modelopt_state + from modelopt.torch.quantization.conversion import quantizer_state as get_quantizer_state + + os.makedirs(output_dir, exist_ok=True) + + # Remove accelerate hooks before saving to avoid pickling errors in modelopt_state. + # Accelerate hooks contain local functions (closures like 'add_hook_to_module..new_forward') + # that can't be pickled. Even after removing hooks from modules, they may still be captured + # in closures within quantizer_state metadata when modelopt_state() calls update_last_state_before_save(). 
+ try: + from accelerate.hooks import remove_hook_from_module + + remove_hook_from_module(model, recurse=True) + except ImportError: + pass + + # Save model weights first (without modelopt_state to avoid pickling error) + model.save_pretrained(output_dir, save_modelopt_state=False) + + # Manually save modelopt_state after removing hooks and rebuilding quantizer_state. + # We need to rebuild quantizer_state because hooks may have been captured in closures + # when quantizer_state() was called during update_last_state_before_save() inside modelopt_state(). + if ModeloptStateManager.is_converted(model): + modelopt_state_path = os.path.join(output_dir, "modelopt_state.pth") + state = modelopt_state(model) + + # Rebuild quantizer_state in metadata to remove any hook references captured in closures + if "modelopt_state_dict" in state and isinstance(state["modelopt_state_dict"], list): + cleaned_state_dict = [] + for entry in state["modelopt_state_dict"]: + if isinstance(entry, tuple) and len(entry) >= 2: + mode_str, state_dict_entry = entry[0], entry[1] + if isinstance(state_dict_entry, dict) and "metadata" in state_dict_entry: + # Rebuild quantizer_state after hooks are removed + cleaned_entry = state_dict_entry.copy() + cleaned_metadata = cleaned_entry["metadata"].copy() + cleaned_metadata["quantizer_state"] = get_quantizer_state(model) + cleaned_entry["metadata"] = cleaned_metadata + cleaned_state_dict.append((mode_str, cleaned_entry)) + else: + cleaned_state_dict.append(entry) + else: + cleaned_state_dict.append(entry) + state["modelopt_state_dict"] = cleaned_state_dict + + torch.save(state, modelopt_state_path) + print_rank_0(f"Saved ModelOpt state to {modelopt_state_path}") + + +def _save_gptq_checkpoint( + model: nn.Module, checkpoint_dir: str, last_layer_idx: int, total_layers: int +) -> None: + """Save intermediate GPTQ checkpoint with metadata for resume support. 
+ + Saves accelerate hooks before calling save_fake_checkpoint (which removes them), + then re-attaches them so the model remains functional for subsequent layers. + """ + print_rank_0( + f"Saving GPTQ checkpoint after layer {last_layer_idx}/{total_layers - 1} to {checkpoint_dir}" + ) + + # Save accelerate hooks before save_fake_checkpoint removes them. + # We need to re-attach them after saving so the model keeps working. + saved_hooks = {} + for name, module in model.named_modules(): + if hasattr(module, "_hf_hook"): + saved_hooks[name] = module._hf_hook + + try: + save_fake_checkpoint(model, checkpoint_dir) + finally: + # Re-attach accelerate hooks so the model keeps working for remaining layers. + if saved_hooks: + try: + from accelerate.hooks import add_hook_to_module + + name_to_module = dict(model.named_modules()) + for name, hook in saved_hooks.items(): + if name in name_to_module: + add_hook_to_module(name_to_module[name], hook) + print_rank_0(f"Re-attached {len(saved_hooks)} accelerate hooks") + except ImportError: + pass + + # Save checkpoint metadata for resume support. + meta = { + "last_completed_layer": last_layer_idx, + "total_layers": total_layers, + "timestamp": datetime.datetime.now().isoformat(), + } + meta_path = os.path.join(checkpoint_dir, "gptq_checkpoint_meta.json") + with open(meta_path, "w") as f: + json.dump(meta, f, indent=2) + print_rank_0(f"GPTQ checkpoint saved (layer {last_layer_idx}/{total_layers - 1})") + + @torch.no_grad() def sequential_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, + checkpoint_every_n_layers: int | None = None, + checkpoint_dir: str | None = None, + resume_from_layer: int = 0, **calib_kwargs, ): """Sequential calibration - a sequential layer-by-layer calibration algorithm. 
@@ -1917,6 +2034,8 @@ def _layer_forward_loop(m, _inputs=layer_inputs): torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() + + print_rank_0("Sequential calibration completed") @torch.no_grad() @@ -1992,7 +2111,9 @@ def hessian_hook(module, input, output): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: state = hessian_state[module.name] hessian = state["hessian"].to(module.weight.device) - blockwise_weight_update(module, hessian, block_size, percdamp) + blockwise_weight_update( + module, hessian, block_size, percdamp, n_samples=state["n_samples"] + ) # Free memory del hessian_state[module.name] torch.cuda.empty_cache() diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3ff7401ec3..a62d8620b1 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1346,10 +1346,19 @@ def global_amax(self, value): def _fake_quantize(self, inputs): """Fake quantization using two-level scaling with _amax and _global_amax.""" if self.amax is not None: + # Ensure amax/global_amax are on the same device as inputs. + # After from_pretrained with device_map, quantizer buffers may remain + # on CPU while model weights/activations are on GPU. 
+ amax = self.amax + if amax.device != inputs.device: + amax = amax.to(inputs.device) + global_amax = self.global_amax + if global_amax is not None and global_amax.device != inputs.device: + global_amax = global_amax.to(inputs.device) return static_blockwise_fp4_fake_quant( inputs, - self.amax, - self.global_amax, # Can be None, will be computed internally + amax, + global_amax, # Can be None, will be computed internally True, # quantize_block_scales inputs.dtype, self._pass_through_bwd, From f9b1487a0ccc3ec1b856cb06d23997b95a129cd5 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 17 Feb 2026 21:17:38 +0000 Subject: [PATCH 14/52] added activationmse logging helper Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/activation_mse.py | 787 ++++++++++++++++++ 1 file changed, 787 insertions(+) create mode 100644 modelopt/torch/quantization/activation_mse.py diff --git a/modelopt/torch/quantization/activation_mse.py b/modelopt/torch/quantization/activation_mse.py new file mode 100644 index 0000000000..df90c84a3a --- /dev/null +++ b/modelopt/torch/quantization/activation_mse.py @@ -0,0 +1,787 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Per-layer activation MSE measurement for quantization analysis. 
+ +This module provides utilities to measure per-linear-layer MSE between a model's +activations before and after quantization. Inspired by FP-Quant's two-phase approach: + +- **Phase 1** (before quantization): ``collect_activations()`` runs the model on + calibration data and stores per-layer outputs in CPU RAM. +- **Phase 2** (after quantization): ``measure_activation_mse()`` runs the quantized + model on the same data and computes MSE on-the-fly against the stored Phase 1 + outputs. Only running scalar accumulators are kept -- no second set of tensors + is stored. + +Typical usage in hf_ptq.py:: + + # Phase 1: before quantization + orig_acts = mtq.collect_activations(model, mse_dataloader, max_samples=16) + + # Quantize + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + + # Phase 2: after quantization -- computes MSE incrementally + mse = mtq.measure_activation_mse(model, mse_dataloader, orig_acts, max_samples=16) +""" + +import contextlib +import fnmatch +import hashlib +import os +from collections.abc import Iterable +from datetime import datetime + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +from modelopt.torch.utils.network import get_decoder_layers + +__all__ = ["ActivationMSELogger", "collect_activations", "measure_activation_mse"] + + +def _tensor_from_output(out) -> torch.Tensor: + """Extract a single tensor from a layer's output (handles tuple returns).""" + if isinstance(out, torch.Tensor): + return out.detach() + return out[0].detach() + + +def _is_linear(module: nn.Module) -> bool: + """Check if a module is a linear layer (covers both nn.Linear and quantized linear).""" + return isinstance(module, nn.Linear) + + +def _matches_filter(name: str, layer_filter: str | None) -> bool: + """Check if a layer name matches the optional filter pattern (fnmatch-style).""" + if layer_filter is None: + return True + return fnmatch.fnmatch(name, layer_filter) + + +def _discover_target_layers( + 
model: nn.Module, + layer_filter: str | None = None, +) -> dict[str, nn.Module]: + """Discover linear layers within decoder blocks of the model. + + Uses get_decoder_layers() to find transformer blocks, then finds all linear + submodules within those blocks. Falls back to all linear layers in the model + if decoder blocks cannot be identified. + + Args: + model: The model to inspect. + layer_filter: Optional fnmatch pattern to select specific layers + (e.g., ``"*self_attn*"``). + + Returns: + Dict mapping full module path -> module reference. + """ + decoder_layers = get_decoder_layers(model) + + targets: dict[str, nn.Module] = {} + + if decoder_layers is not None: + # Build a reverse lookup: module id -> full name in model + module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} + + for block in decoder_layers: + block_name = module_to_name.get(id(block), "") + for sub_name, sub_mod in block.named_modules(): + if _is_linear(sub_mod): + full_name = f"{block_name}.{sub_name}" if block_name else sub_name + if _matches_filter(full_name, layer_filter): + targets[full_name] = sub_mod + else: + # Fallback: scan all modules + for name, module in model.named_modules(): + if _is_linear(module): + if _matches_filter(name, layer_filter): + targets[name] = module + + return targets + + +def _run_batch(model: nn.Module, batch) -> None: + """Run a single batch through the model.""" + if isinstance(batch, dict): + model(**batch) + elif isinstance(batch, (list, tuple)): + model(*batch) + else: + model(batch) + + +@torch.no_grad() +def collect_activations( + model: nn.Module, + dataloader: Iterable, + max_samples: int | None = None, + layer_filter: str | None = None, +) -> dict[str, list[torch.Tensor]]: + """Collect per-linear-layer output activations into CPU memory (Phase 1). + + Registers forward hooks on linear layers within the model's decoder blocks, + runs calibration data through the model, and returns captured per-layer outputs. 
+ + Args: + model: The model to collect activations from (typically pre-quantization). + dataloader: An iterable yielding batches (dicts with ``input_ids``, etc.). + Use batch_size=1 to minimize memory. + max_samples: Maximum number of batches to process. ``None`` means all. + layer_filter: Optional fnmatch pattern to restrict which layers are + collected (e.g., ``"*self_attn*"``). ``None`` means all linear layers + inside decoder blocks. + + Returns: + Dict mapping layer name to a list of output tensors (one per batch, on CPU). + """ + was_training = model.training + model.eval() + + # Discover target linear layers + targets = _discover_target_layers(model, layer_filter) + if not targets: + raise ValueError( + f"No linear layers found matching the given filter. layer_filter={layer_filter!r}" + ) + + print(f"Collecting activations for {len(targets)} layers...") + + # Storage: {layer_name: [tensor_per_batch, ...]} + saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} + captured: dict[str, torch.Tensor] = {} + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + # Register hooks + hooks = [] + for name, module in targets.items(): + hooks.append(module.register_forward_hook(_make_hook(name))) + + try: + n_batches = 0 + for batch in tqdm(dataloader, desc="Collecting activations", leave=False): + if max_samples is not None and n_batches >= max_samples: + break + + captured.clear() + _run_batch(model, batch) + + for name in targets: + if name in captured: + saved[name].append(captured[name]) + + n_batches += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + print(f"Collected {n_batches} samples across {len(targets)} layers") + return saved + + +@torch.no_grad() +def measure_activation_mse( + model: nn.Module, + dataloader: Iterable, + orig_activations: dict[str, list[torch.Tensor]], + max_samples: int | None = None, + 
layer_filter: str | None = None, +) -> dict[str, float]: + """Compute per-layer MSE between stored and live activations (Phase 2). + + Runs the (quantized) model on calibration data and computes MSE on-the-fly + against the pre-quantization activations stored by :func:`collect_activations`. + + Only scalar accumulators (sum of squared errors and element count) are kept + per layer -- no second set of activation tensors is stored. + + The MSE for each layer is computed as:: + + MSE = sum_over_all_elements((orig - quant) ^ 2) / total_elements + + Args: + model: The quantized model to measure. + dataloader: Same dataloader used for :func:`collect_activations` + (must yield batches in the same order). + orig_activations: Output of :func:`collect_activations` -- dict mapping + layer name to a list of pre-quantization output tensors. + max_samples: Maximum number of batches to process (should match Phase 1). + layer_filter: Optional fnmatch pattern (should match Phase 1). + + Returns: + Dict mapping layer name to its MSE value. + """ + was_training = model.training + model.eval() + + # Discover target layers on the (now-quantized) model + targets = _discover_target_layers(model, layer_filter) + + # Only measure layers that exist in both the model and orig_activations + common_keys = sorted(set(targets.keys()) & set(orig_activations.keys())) + if not common_keys: + raise ValueError( + "No matching layers between the quantized model and stored activations. " + "Ensure the same layer_filter is used for both phases." 
+ ) + + skipped = set(orig_activations.keys()) - set(targets.keys()) + if skipped: + print(f"Warning: {len(skipped)} layers in orig_activations not found in model (skipped)") + + print(f"Computing activation MSE for {len(common_keys)} layers...") + + # Scalar accumulators + sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) + count: dict[str, int] = dict.fromkeys(common_keys, 0) + + captured: dict[str, torch.Tensor] = {} + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + # Register hooks only on common layers + hooks = [targets[name].register_forward_hook(_make_hook(name)) for name in common_keys] + + try: + batch_idx = 0 + for batch in tqdm(dataloader, desc="Computing activation MSE", leave=False): + if max_samples is not None and batch_idx >= max_samples: + break + + captured.clear() + _run_batch(model, batch) + + for name in common_keys: + if name not in captured: + continue + if batch_idx >= len(orig_activations.get(name, [])): + continue + + o = orig_activations[name][batch_idx].float() + q = captured[name].float() + + if o.shape != q.shape: + print( + f"Warning: shape mismatch for {name} batch {batch_idx}: " + f"{o.shape} vs {q.shape}, skipping" + ) + continue + + sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() + count[name] += o.numel() + + batch_idx += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + mse = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in common_keys + } + + return mse + + +# --------------------------------------------------------------------------- +# Portable ActivationMSELogger class +# --------------------------------------------------------------------------- + + +def _portable_discover_target_layers( + model: nn.Module, + layer_filter: str | None = None, +) -> dict[str, nn.Module]: + """Discover linear layers in decoder blocks with a portable fallback 
chain. + + Strategy: + 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). + 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). + 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. + + Within each set of decoder blocks the function collects every ``nn.Linear`` + sub-module and optionally filters by *layer_filter* (fnmatch pattern). + """ + decoder_layers = None + + # 1. Try modelopt helper (may not exist when file is copied elsewhere) + with contextlib.suppress(Exception): + decoder_layers = get_decoder_layers(model) + + # 2. Try common HF / other patterns + if decoder_layers is None: + for attr_chain in ( + ("model", "layers"), + ("decoder", "layers"), + ("transformer", "h"), + ("backbone", "layers"), + ): + obj = model + try: + for attr in attr_chain: + obj = getattr(obj, attr) + if isinstance(obj, nn.ModuleList): + decoder_layers = obj + break + except AttributeError: + continue + + targets: dict[str, nn.Module] = {} + + if decoder_layers is not None: + module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} + for block in decoder_layers: + block_name = module_to_name.get(id(block), "") + for sub_name, sub_mod in block.named_modules(): + if isinstance(sub_mod, nn.Linear): + full_name = f"{block_name}.{sub_name}" if block_name else sub_name + if _matches_filter(full_name, layer_filter): + targets[full_name] = sub_mod + else: + # 3. Fallback: all linear layers + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + if _matches_filter(name, layer_filter): + targets[name] = module + + return targets + + +class ActivationMSELogger: + """Portable activation MSE logger for comparing original vs quantized models. + + Works with both: + + - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` + or ``[B, seq_len]``, consumed via ``model(tensor)``. 
+ - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): + ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. + + Guarantees same samples are used for both phases via SHA-256 hashing of + input tensors. Supports saving / loading all activations to disk for + later cross-codebase comparison. + + Example (ModelOpt -- DataLoader with dict batches):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model, dataloader, phase="original") + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + mse_logger.collect(model, dataloader, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + + Example (FP-Quant -- List[Tensor]):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model_orig, calibration_data, phase="original") + mse_logger.collect(model_quant, calibration_data, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + """ + + def __init__( + self, + max_samples: int = 16, + layer_filter: str | None = None, + save_dir: str | None = None, + ): + """Initialize the ActivationMSELogger. + + Args: + max_samples: Maximum number of calibration batches to process per phase. + layer_filter: Optional glob pattern to restrict which layers are tracked. + save_dir: Optional directory path for persisting activation data to disk. 
+ """ + self.max_samples = max_samples + self.layer_filter = layer_filter + self.save_dir = save_dir + + # Per-phase state + self.original_activations: dict[str, list[torch.Tensor]] = {} + self.quantized_activations: dict[str, list[torch.Tensor]] = {} + self.input_hashes: list[str] = [] # hashes for "original" phase + self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase + + # Computed after both phases + self.mse_results: dict[str, float] | None = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @torch.no_grad() + def collect( + self, + model: nn.Module, + data: Iterable, + phase: str, + target_modules: dict[str, nn.Module] | None = None, + ) -> None: + """Collect per-linear-layer output activations for a given phase. + + Args: + model: The model to run (original or quantized). + data: An iterable of batches. Each batch can be: + + - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). + - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). + - ``list`` / ``tuple`` of tensors. + phase: ``"original"`` or ``"quantized"``. + target_modules: Optional explicit mapping of ``{name: nn.Module}`` + to attach hooks to. If *None*, layers are auto-discovered + via decoder-block scanning. + """ + if phase not in ("original", "quantized"): + raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") + + was_training = model.training + model.eval() + + # ----- layer discovery ----- + targets = ( + target_modules + if target_modules is not None + else (_portable_discover_target_layers(model, self.layer_filter)) + ) + if not targets: + raise ValueError( + "No linear layers found. Provide target_modules explicitly or " + f"check layer_filter={self.layer_filter!r}." 
+ ) + + print( + f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " + f"max_samples={self.max_samples}" + ) + + # ----- storage ----- + saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} + captured: dict[str, torch.Tensor] = {} + hashes: list[str] = [] + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + hooks = [] + for name, module in targets.items(): + hooks.append(module.register_forward_hook(_make_hook(name))) + + try: + n_batches = 0 + for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): + if self.max_samples is not None and n_batches >= self.max_samples: + break + + captured.clear() + self._run_batch(model, batch) + + for name in targets: + if name in captured: + saved[name].append(captured[name]) + + hashes.append(self._hash_batch(batch)) + n_batches += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + # ----- store results on self ----- + if phase == "original": + self.original_activations = saved + self.input_hashes = hashes + else: + self.quantized_activations = saved + self.quant_input_hashes = hashes + # Verify sample consistency + if self.input_hashes: + self._verify_hashes() + + # Invalidate any previous MSE since we have new activations + self.mse_results = None + + print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") + + def compute_mse(self) -> dict[str, float]: + """Compute per-layer MSE between original and quantized activations. + + Returns: + Dict mapping layer name to its MSE value. + + Raises: + ValueError: If either phase has not been collected yet. + """ + if not self.original_activations: + raise ValueError( + "No original activations collected. Call collect(..., phase='original') first." + ) + if not self.quantized_activations: + raise ValueError( + "No quantized activations collected. 
Call collect(..., phase='quantized') first." + ) + + common_keys = sorted( + set(self.original_activations.keys()) & set(self.quantized_activations.keys()) + ) + if not common_keys: + raise ValueError( + "No matching layer names between original and quantized activations. " + "Ensure the same model architecture / layer_filter is used for both phases." + ) + + orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) + quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) + if orig_only: + print( + f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" + ) + if quant_only: + print( + f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" + ) + + sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) + count: dict[str, int] = dict.fromkeys(common_keys, 0) + + for name in common_keys: + orig_list = self.original_activations[name] + quant_list = self.quantized_activations[name] + n = min(len(orig_list), len(quant_list)) + for i in range(n): + o = orig_list[i].float() + q = quant_list[i].float() + if o.shape != q.shape: + print( + f"[ActivationMSELogger] Warning: shape mismatch for {name} " + f"batch {i}: {o.shape} vs {q.shape}, skipping" + ) + continue + sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() + count[name] += o.numel() + + self.mse_results = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") + for key in common_keys + } + return self.mse_results + + def save(self, path: str | None = None) -> str: + """Save all state (activations, hashes, MSE) to disk via ``torch.save``. + + Args: + path: Explicit file path. If *None*, a timestamped file is created + inside ``self.save_dir`` (which must be set). + + Returns: + The path where the file was saved. 
+ """ + if path is None: + if self.save_dir is None: + raise ValueError("Provide a path or set save_dir in the constructor.") + os.makedirs(self.save_dir, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") + + payload = { + "max_samples": self.max_samples, + "layer_filter": self.layer_filter, + "input_hashes": self.input_hashes, + "quant_input_hashes": self.quant_input_hashes, + "original_activations": self.original_activations, + "quantized_activations": self.quantized_activations, + "mse": self.mse_results, + } + torch.save(payload, path) + print(f"[ActivationMSELogger] Saved to {path}") + return path + + @classmethod + def load(cls, path: str) -> "ActivationMSELogger": + """Load a previously saved ``ActivationMSELogger`` from disk. + + Args: + path: Path to the ``.pt`` file created by :meth:`save`. + + Returns: + A new ``ActivationMSELogger`` instance with restored state. + """ + payload = torch.load(path, map_location="cpu", weights_only=False) + logger = cls( + max_samples=payload.get("max_samples", 16), + layer_filter=payload.get("layer_filter"), + ) + logger.original_activations = payload.get("original_activations", {}) + logger.quantized_activations = payload.get("quantized_activations", {}) + logger.input_hashes = payload.get("input_hashes", []) + logger.quant_input_hashes = payload.get("quant_input_hashes", []) + logger.mse_results = payload.get("mse") + print(f"[ActivationMSELogger] Loaded from {path}") + return logger + + def summary(self) -> str: + """Return a formatted string summarising per-layer MSE results. + + Computes MSE first if not already done. 
+ """ + if self.mse_results is None: + self.compute_mse() + assert self.mse_results is not None + + lines = ["Per-layer activation MSE (original vs quantized):"] + lines.extend( + f" {key}: {self.mse_results[key]:.6e}" for key in sorted(self.mse_results.keys()) + ) + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Pre-materialized MSE data (cross-run / cross-codebase safety) + # ------------------------------------------------------------------ + + @staticmethod + def materialize_data( + data: Iterable, + path: str, + max_samples: int | None = None, + ) -> list[torch.Tensor]: + """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. + + Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a + single ``input_ids`` CPU tensor before saving. The resulting file is a + plain ``List[Tensor]`` that can be loaded in **any** codebase and passed + straight to :meth:`collect`. + + If *path* already exists it is **not** overwritten -- call + :meth:`load_data` instead. + + Args: + data: Iterable of batches (DataLoader, List[Tensor], etc.). + path: Destination ``.pt`` file path. + max_samples: How many batches to keep. ``None`` means all. + + Returns: + The materialised list of CPU tensors (same object that was saved). 
+ """ + samples: list[torch.Tensor] = [] + for batch in data: + if max_samples is not None and len(samples) >= max_samples: + break + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + samples.append(t.cpu()) + + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + torch.save(samples, path) + print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") + return samples + + @staticmethod + def load_data(path: str) -> list[torch.Tensor]: + """Load a previously materialised MSE input set. + + Args: + path: Path to the ``.pt`` file created by :meth:`materialize_data`. + + Returns: + ``List[Tensor]`` of input batches (on CPU). + """ + samples = torch.load(path, map_location="cpu", weights_only=True) + print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") + return samples + + # ------------------------------------------------------------------ + # Static / private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _run_batch(model: nn.Module, batch) -> None: + """Run a single batch through the model (handles Tensor, dict, list/tuple). + + Automatically moves inputs to the model's device so that CPU-stored + materialized data works transparently with a CUDA model. 
+ """ + device = next(model.parameters()).device + if isinstance(batch, dict): + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() + } + model(**batch) + elif isinstance(batch, torch.Tensor): + model(batch.to(device)) + elif isinstance(batch, (list, tuple)): + batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) + model(*batch) + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + + @staticmethod + def _hash_batch(batch) -> str: + """Compute SHA-256 hash of the primary input tensor in *batch*. + + - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). + - ``Tensor`` -> hashes the tensor directly. + - ``list/tuple`` -> hashes the first element. + """ + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] if batch else None + else: + return "" + + if t is None or not isinstance(t, torch.Tensor): + return "" + return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() + + def _verify_hashes(self) -> None: + """Compare input hashes between original and quantized phases.""" + n = min(len(self.input_hashes), len(self.quant_input_hashes)) + mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) + if mismatches: + print( + f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " + f"different input hashes between original and quantized phases. " + f"The same data may not have been used for both phases!" 
+ ) + else: + print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") From b62f3813fc61ca530816c15e93a4cf25033421ab Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 19 Feb 2026 23:47:48 +0000 Subject: [PATCH 15/52] input amax sync added + tested gptq super sft checkpoint Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 29 ++++++++ tests/gpu/torch/quantization/test_gptq.py | 87 ++++++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 18a98e0c31..8b2ec97fd8 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1863,6 +1863,35 @@ def _set_input_quantizers_quant_mode(layer: nn.Module): module.disable_calib() +def _set_kv_quantizers_calib_mode(layer: nn.Module): + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + module._calibrator.reset() + module.disable_quant() + module.enable_calib() + + +def _set_kv_quantizers_quant_mode(layer: nn.Module): + for name, module in layer.named_modules(): + if ( + isinstance(module, TensorQuantizer) + and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) + and not module._disabled + and not module._dynamic + and module._calibrator is not None + ): + if module._calibrator.compute_amax() is not None: + module.load_calib_amax() + module.enable_quant() + module.disable_calib() + + @contextlib.contextmanager def _disable_input_quantizers(layer: nn.Module): """Temporarily disable all enabled input quantizers in a layer.""" diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 
d43177cae2..ec95e1e8d3 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -20,7 +20,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer import modelopt.torch.quantization as mtq +from modelopt.torch.export.unified_export_hf import _export_quantized_weight from modelopt.torch.quantization.model_calib import blockwise_weight_update, update_hessian +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader RAND_SEED = 42 @@ -156,6 +158,91 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): assert torch.allclose(model.weight.data, q_dq_weight), "Weight should be equal" +def test_gptq_export_roundtrip(): + """Test that GPTQ export + dequantize produces weights matching in-memory QDQ.""" + torch.manual_seed(RAND_SEED) + dim = 128 + block_size = 4 + + # Step 1: Create a simple linear model and quantize to install NVFP4 quantizers + model = torch.nn.Linear(dim, dim).to("cuda") + model.name = "linear" + original_weight = model.weight.data.clone() + input_tensor = torch.randn(2, 16, dim).to("cuda") + quant_cfg = mtq.NVFP4_DEFAULT_CFG + + mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(input_tensor)) + + # Restore original weight before GPTQ + model.weight.data = original_weight.clone() + + # Step 2: Perform GPTQ — compute Hessian and update weights + hessian = torch.zeros(dim, dim, dtype=torch.float32) + n_samples = 0 + hessian, n_samples = update_hessian(input_tensor, hessian, n_samples) + hessian = hessian.to("cuda") + + blockwise_weight_update(model, hessian, block_size, percdamp=0.1) + + # Save the QDQ reference from the quantizer applied to GPTQ'd weights + gptq_weight_shape = model.weight.data.shape + gptq_weight_dtype = model.weight.data.dtype + qdq_ref = model.weight.data.clone() + + # Step 3: Export — converts weight to packed NVFP4 and registers scale 
buffers + _export_quantized_weight(model, torch.bfloat16) + + # Verify export produced the expected buffers + assert hasattr(model, "weight_scale"), "Export should register weight_scale buffer" + assert hasattr(model, "weight_scale_2"), "Export should register weight_scale_2 buffer" + + # Step 4: Dequantize the exported packed weight and compare with QDQ reference + packed_weight = model.weight.data + weight_scale = model.weight_scale + weight_scale_2 = model.weight_scale_2 + + nvfp4_qtensor = NVFP4QTensor(gptq_weight_shape, gptq_weight_dtype, packed_weight) + deq_weight = nvfp4_qtensor.dequantize( + dtype=torch.bfloat16, + scale=weight_scale, + double_scale=weight_scale_2, + block_sizes={-1: 16}, + ) + + assert deq_weight.shape == qdq_ref.shape, ( + f"Shape mismatch: dequantized {deq_weight.shape} vs QDQ ref {qdq_ref.shape}" + ) + diff = (deq_weight - qdq_ref.to(torch.bfloat16)).abs() + max_diff = diff.max().item() + max_diff_idx = diff.argmax().item() + max_diff_row = max_diff_idx // deq_weight.shape[1] + max_diff_col = max_diff_idx % deq_weight.shape[1] + num_mismatched = (diff > 1e-3).sum().item() + total_elements = diff.numel() + + print("\n--- Diff Stats ---") + print(f" Max diff: {max_diff}") + print(f" Mean diff: {diff.mean().item()}") + print(f" Median diff: {diff.median().item()}") + print(f" Std diff: {diff.std().item()}") + print( + f" Mismatched (>1e-3): {num_mismatched}/{total_elements} " + f"({100 * num_mismatched / total_elements:.2f}%)" + ) + print( + f" Max diff at [{max_diff_row}, {max_diff_col}]: " + f"deq={deq_weight[max_diff_row, max_diff_col].item()}, " + f"qdq_ref={qdq_ref[max_diff_row, max_diff_col].item()}" + ) + + assert torch.allclose(deq_weight, qdq_ref.to(torch.bfloat16), atol=1e-2), ( + f"Dequantized weight does not match QDQ reference. 
" + f"Max diff: {max_diff} at [{max_diff_row}, {max_diff_col}] " + f"(deq={deq_weight[max_diff_row, max_diff_col].item()}, " + f"qdq_ref={qdq_ref[max_diff_row, max_diff_col].item()})" + ) + + @pytest.mark.parametrize( "quant_cfg", [mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG] ) From 636c88a88ff5e0d7c71e10854074b982d59de154 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 23 Feb 2026 21:24:34 +0000 Subject: [PATCH 16/52] checkpoints generated on 0223 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 1 - modelopt/torch/quantization/config.py | 9 ++-- modelopt/torch/quantization/model_calib.py | 57 ++++++++++++++++------ 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 53c1f2de19..5905829a23 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -886,7 +886,6 @@ def _compute_perplexity(model, data, batch_size: int = 1): ppl = _compute_perplexity(full_model, eval_data) print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization if args.vllm_fakequant_export: diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index a174fb3291..08344a2bc5 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -257,7 +257,7 @@ def find_quant_cfg_entry_by_path( "enable": True, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + **super_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, @@ -279,7 +279,7 @@ def find_quant_cfg_entry_by_path( "enable": True, }, 
**_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + **super_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, @@ -446,7 +446,7 @@ def find_quant_cfg_entry_by_path( "enable": False, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, + # **_mamba_moe_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, @@ -471,9 +471,6 @@ def find_quant_cfg_entry_by_path( "enable": True, }, **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear }, "algorithm": { "method": "gptq", diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 8b2ec97fd8..8c7feca670 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -137,6 +137,26 @@ def max_calibrate( if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_expert_weight_amax) + for name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + # Get the initial amax from max calibration + initial_amax = module._amax.clone().detach() + + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + + if is_nvfp4_static: + # Compute and set global_amax + global_amax = reduce_amax(initial_amax, axis=None) + + # Convert to NVFP4StaticQuantizer in-place + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + if not 
distributed_sync: return @@ -341,6 +361,7 @@ def mse_calibrate( if fp8_scale_sweep and is_nvfp4_static: # Replace calibrator with NVFP4MSECalibrator + print("mse_calibrate: Replacing calibrator with NVFP4MSECalibrator") module._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=module._calibrator._axis, @@ -627,6 +648,7 @@ def quant_func(x, amax, quantizer=weight_quantizer): error_func = helper.get_error_func() if fp8_scale_sweep and is_nvfp4_static: + print("local_hessian_calibrate: Replacing calibrator with NVFP4MSECalibrator") weight_quantizer._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, @@ -2102,21 +2124,26 @@ def gptq( "n_samples": 0, } - # Phase 2: Register hooks to collect Hessians during forward passes - def hessian_hook(module, input, output): - """Hook to intercept activations and update hessian matrix.""" - if hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled: - inp = module.input_quantizer(input[0]) - else: - inp = input[0] - state = hessian_state[module.name] - hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) - hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + # Phase 2: Patch forwards to collect Hessians (similar to local_hessian_calibrate) + def _make_hessian_forward(module_name): + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + state = hessian_state[module_name] + hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) + hessian_state[module_name] = {"hessian": hessian, "n_samples": n_samples} + + self.weight_quantizer.disable() + out = self._forward_no_gptq_hessian(input, *args, **kwargs) + self.weight_quantizer.enable() + return out + + return hessian_forward - handles = [] + patched_modules = [] for name, module in layer.named_modules(): if is_quantized_linear(module) and 
module.weight_quantizer.is_enabled: - handles.append(module.register_forward_hook(hessian_hook)) + bind_forward_method(module, _make_hessian_forward(name), "_forward_no_gptq_hessian") + patched_modules.append(module) # Run forward passes with the provided inputs to collect Hessians hessian_start = time.time() @@ -2126,9 +2153,9 @@ def hessian_hook(module, input, output): for args, kwargs_input in inputs: layer(*args, **kwargs_input) - # Remove hooks after collecting Hessians - for handle in handles: - handle.remove() + # Unpatch forwards + for module in patched_modules: + unpatch_forward_method(module, "_forward_no_gptq_hessian") torch.cuda.synchronize() if torch.cuda.is_available() else None hessian_time = time.time() - hessian_start From 88c990884576962ed93fd3975069891350fe02a9 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:20:37 +0000 Subject: [PATCH 17/52] tested perplexity Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 134 -------------- modelopt/torch/quantization/mode.py | 2 + modelopt/torch/quantization/model_calib.py | 205 +-------------------- 3 files changed, 5 insertions(+), 336 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 08344a2bc5..09a2bf559f 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -434,125 +434,6 @@ def find_quant_cfg_entry_by_path( }, } -NVFP4_STATIC_WO_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - # **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, - 
"algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "max", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_DYNAMIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - 
}, -} INT4_AWQ_CFG = { "quant_cfg": [ @@ -1406,21 +1287,6 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) - checkpoint_every_n_layers: int | None = ModeloptField( - default=None, - title="Save intermediate checkpoint every N layers during sequential calibration.", - ) - - checkpoint_dir: str | None = ModeloptField( - default=None, - title="Directory for saving/loading intermediate GPTQ checkpoints.", - ) - - resume_from_layer: int = ModeloptField( - default=0, - title="Layer index to resume sequential calibration from (0 = start from beginning).", - ) - class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 88e93bb770..efc66ffa94 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -255,6 +255,8 @@ def wrapped_calib_func( else: # Direct calibration (existing behavior) func(model, forward_loop=forward_loop, **kwargs) + else: + raise ValueError(f"No calibration function provided for method: {method}") # Lets get the latest metadata for the quantizer states metadata = {} diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 8c7feca670..24f5a19cd4 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1854,196 +1854,11 @@ def hessian_hook(module, input, output): print_rank_0("GPTQ-lite quantization completed successfully") -def _set_input_quantizers_calib_mode(layer: nn.Module): - """Set all input quantizers of a layer to calibration mode.""" - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and "input_quantizer" in name - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - module._calibrator.reset() - module.disable_quant() - module.enable_calib() - - -def _set_input_quantizers_quant_mode(layer: 
nn.Module): - """Load fresh amaxes and restore all input quantizers of a layer to quant mode.""" - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and "input_quantizer" in name - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - if module._calibrator.compute_amax() is not None: - module.load_calib_amax() - module.enable_quant() - module.disable_calib() - - -def _set_kv_quantizers_calib_mode(layer: nn.Module): - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - module._calibrator.reset() - module.disable_quant() - module.enable_calib() - - -def _set_kv_quantizers_quant_mode(layer: nn.Module): - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and ("k_bmm_quantizer" in name or "v_bmm_quantizer" in name) - and not module._disabled - and not module._dynamic - and module._calibrator is not None - ): - if module._calibrator.compute_amax() is not None: - module.load_calib_amax() - module.enable_quant() - module.disable_calib() - - -@contextlib.contextmanager -def _disable_input_quantizers(layer: nn.Module): - """Temporarily disable all enabled input quantizers in a layer.""" - enabled_quantizers = [] - for name, module in layer.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and "input_quantizer" in name - and not module._disabled - ): - module.disable() - enabled_quantizers.append(module) - try: - yield - finally: - for module in enabled_quantizers: - module.enable() - - -def save_fake_checkpoint(model: nn.Module, output_dir: str) -> None: - """Save fake quant checkpoint using save_pretrained() (HuggingFace format). - - Args: - model: The quantized model to save. - output_dir: Directory to write the checkpoint into. 
- """ - from modelopt.torch.opt.conversion import ModeloptStateManager, modelopt_state - from modelopt.torch.quantization.conversion import quantizer_state as get_quantizer_state - - os.makedirs(output_dir, exist_ok=True) - - # Remove accelerate hooks before saving to avoid pickling errors in modelopt_state. - # Accelerate hooks contain local functions (closures like 'add_hook_to_module..new_forward') - # that can't be pickled. Even after removing hooks from modules, they may still be captured - # in closures within quantizer_state metadata when modelopt_state() calls update_last_state_before_save(). - try: - from accelerate.hooks import remove_hook_from_module - - remove_hook_from_module(model, recurse=True) - except ImportError: - pass - - # Save model weights first (without modelopt_state to avoid pickling error) - model.save_pretrained(output_dir, save_modelopt_state=False) - - # Manually save modelopt_state after removing hooks and rebuilding quantizer_state. - # We need to rebuild quantizer_state because hooks may have been captured in closures - # when quantizer_state() was called during update_last_state_before_save() inside modelopt_state(). 
- if ModeloptStateManager.is_converted(model): - modelopt_state_path = os.path.join(output_dir, "modelopt_state.pth") - state = modelopt_state(model) - - # Rebuild quantizer_state in metadata to remove any hook references captured in closures - if "modelopt_state_dict" in state and isinstance(state["modelopt_state_dict"], list): - cleaned_state_dict = [] - for entry in state["modelopt_state_dict"]: - if isinstance(entry, tuple) and len(entry) >= 2: - mode_str, state_dict_entry = entry[0], entry[1] - if isinstance(state_dict_entry, dict) and "metadata" in state_dict_entry: - # Rebuild quantizer_state after hooks are removed - cleaned_entry = state_dict_entry.copy() - cleaned_metadata = cleaned_entry["metadata"].copy() - cleaned_metadata["quantizer_state"] = get_quantizer_state(model) - cleaned_entry["metadata"] = cleaned_metadata - cleaned_state_dict.append((mode_str, cleaned_entry)) - else: - cleaned_state_dict.append(entry) - else: - cleaned_state_dict.append(entry) - state["modelopt_state_dict"] = cleaned_state_dict - - torch.save(state, modelopt_state_path) - print_rank_0(f"Saved ModelOpt state to {modelopt_state_path}") - - -def _save_gptq_checkpoint( - model: nn.Module, checkpoint_dir: str, last_layer_idx: int, total_layers: int -) -> None: - """Save intermediate GPTQ checkpoint with metadata for resume support. - - Saves accelerate hooks before calling save_fake_checkpoint (which removes them), - then re-attaches them so the model remains functional for subsequent layers. - """ - print_rank_0( - f"Saving GPTQ checkpoint after layer {last_layer_idx}/{total_layers - 1} to {checkpoint_dir}" - ) - - # Save accelerate hooks before save_fake_checkpoint removes them. - # We need to re-attach them after saving so the model keeps working. 
- saved_hooks = {} - for name, module in model.named_modules(): - if hasattr(module, "_hf_hook"): - saved_hooks[name] = module._hf_hook - - try: - save_fake_checkpoint(model, checkpoint_dir) - finally: - # Re-attach accelerate hooks so the model keeps working for remaining layers. - if saved_hooks: - try: - from accelerate.hooks import add_hook_to_module - - name_to_module = dict(model.named_modules()) - for name, hook in saved_hooks.items(): - if name in name_to_module: - add_hook_to_module(name_to_module[name], hook) - print_rank_0(f"Re-attached {len(saved_hooks)} accelerate hooks") - except ImportError: - pass - - # Save checkpoint metadata for resume support. - meta = { - "last_completed_layer": last_layer_idx, - "total_layers": total_layers, - "timestamp": datetime.datetime.now().isoformat(), - } - meta_path = os.path.join(checkpoint_dir, "gptq_checkpoint_meta.json") - with open(meta_path, "w") as f: - json.dump(meta, f, indent=2) - print_rank_0(f"GPTQ checkpoint saved (layer {last_layer_idx}/{total_layers - 1})") - - @torch.no_grad() def sequential_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, - checkpoint_every_n_layers: int | None = None, - checkpoint_dir: str | None = None, - resume_from_layer: int = 0, **calib_kwargs, ): """Sequential calibration - a sequential layer-by-layer calibration algorithm. 
@@ -2093,14 +1908,14 @@ def _layer_forward_loop(m, _inputs=layer_inputs): def gptq( layer: nn.Module, inputs: list[tuple[tuple, dict]], + forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, **kwargs, ): """GPTQ quantization - a GPTQ variant.""" - import time - - total_start = time.time() + # Set weight amax and activation amax'es for the current layer using max_calibrate + max_calibrate(layer, forward_loop=forward_loop) # Dictionary to store hessian matrices for all linear layers in this decoder hessian_state = {} @@ -2146,7 +1961,6 @@ def hessian_forward(self, input, *args, **kwargs): patched_modules.append(module) # Run forward passes with the provided inputs to collect Hessians - hessian_start = time.time() print_rank_0( f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." ) @@ -2157,11 +1971,8 @@ def hessian_forward(self, input, *args, **kwargs): for module in patched_modules: unpatch_forward_method(module, "_forward_no_gptq_hessian") - torch.cuda.synchronize() if torch.cuda.is_available() else None - hessian_time = time.time() - hessian_start # Phase 3: Update weights using computed Hessians (same as gptq_lite) - weight_update_start = time.time() print_rank_0("Updating weights using GPTQ algorithm...") for name, module in layer.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: @@ -2173,13 +1984,3 @@ def hessian_forward(self, input, *args, **kwargs): # Free memory del hessian_state[module.name] torch.cuda.empty_cache() - - torch.cuda.synchronize() if torch.cuda.is_available() else None - weight_update_time = time.time() - weight_update_start - - total_time = time.time() - total_start - print_rank_0( - f"GPTQ timing - Hessian: {hessian_time:.2f}s, " - f"Weight update: {weight_update_time:.2f}s, " - f"Total: {total_time:.2f}s" - ) From 7ebd5ef8f9c920639029c32ee859e5a944637243 Mon Sep 17 00:00:00 2001 From: Suguna Velury 
<178320438+sugunav14@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:46:47 +0000 Subject: [PATCH 18/52] tested, revert later Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 5905829a23..167d95b0d2 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -741,6 +741,82 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) + if True: + # Disable quantizers + # mtq.fold_weight(full_model) + # print("Folded weights") + print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") + mtq.disable_quantizer(full_model, "*") + if True: + # mtq.fold_weight(full_model) + import os + + import torch.nn.functional as F + from datasets import load_dataset + from tqdm import trange + from transformers import AutoTokenizer + + # Set cache directory to work directory to avoid disk space issues + cache_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), ".hf_cache" + ) + os.makedirs(cache_dir, exist_ok=True) + os.environ["HF_DATASETS_CACHE"] = cache_dir + print(f"Using HuggingFace datasets cache: {cache_dir}") + + def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): + test_dataset_raw = load_dataset( + "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir + ) + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [ + test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] + for i in range(num_test_sequences) + ] + return test_loader + + @torch.no_grad() + def _compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + 
device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed to far + tokens_processed = 0 + # Loop through each batch + for i in trange( + 0, num_samples, batch_size, desc="Computing perplexity", leave=False + ): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.reshape(-1), + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + eval_data = _get_wikitext2(tokenizer, 2048) + ppl = _compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") + + breakpoint() if True: import os From 26ac17445d0c218a8112757d15e246c7da3fc8b6 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 10 Feb 2026 04:41:46 +0000 Subject: [PATCH 19/52] tested Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 220 -------------------------- modelopt/torch/quantization/config.py | 94 +++++++++++ 2 files changed, 94 insertions(+), 220 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 167d95b0d2..877d1a9d29 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -741,226 +741,6 @@ def export_quantized( "Unified HF export format does not specify inference tensor 
parallel or pipeline parallel. " "They will be set at deployment time." ) - if True: - # Disable quantizers - # mtq.fold_weight(full_model) - # print("Folded weights") - print("Disabling quantizers for perplexity evaluation (weights are already QDQ'ed)") - mtq.disable_quantizer(full_model, "*") - if True: - # mtq.fold_weight(full_model) - import os - - import torch.nn.functional as F - from datasets import load_dataset - from tqdm import trange - from transformers import AutoTokenizer - - # Set cache directory to work directory to avoid disk space issues - cache_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), ".hf_cache" - ) - os.makedirs(cache_dir, exist_ok=True) - os.environ["HF_DATASETS_CACHE"] = cache_dir - print(f"Using HuggingFace datasets cache: {cache_dir}") - - def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): - test_dataset_raw = load_dataset( - "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir - ) - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [ - test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] - for i in range(num_test_sequences) - ] - return test_loader - - @torch.no_grad() - def _compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange( - 0, num_samples, batch_size, desc="Computing perplexity", leave=False - ): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # 
Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), - shift_labels.reshape(-1), - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - eval_data = _get_wikitext2(tokenizer, 2048) - ppl = _compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - - breakpoint() - - if True: - import os - - import torch.nn.functional as F - from datasets import load_dataset - from tqdm import trange - from transformers import AutoTokenizer - - # Set cache directory to work directory to avoid disk space issues - cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".hf_cache") - os.makedirs(cache_dir, exist_ok=True) - os.environ["HF_DATASETS_CACHE"] = cache_dir - print(f"Using HuggingFace datasets cache: {cache_dir}") - - def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): - test_dataset_raw = load_dataset( - "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir - ) - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [ - test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] - for i in range(num_test_sequences) - ] - return test_loader - - @torch.no_grad() - def _compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange( - 0, num_samples, batch_size, 
desc="Computing perplexity", leave=False - ): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), - shift_labels.reshape(-1), - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - eval_data = _get_wikitext2(tokenizer, 2048) - ppl = _compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - breakpoint() - - if args.export_qdq_weights: - # Disable quantizers - if "gptq" not in args.qformat: - mtq.fold_weight(full_model) - print("Folded weights") - - print(f"Saving model to {args.export_path}") - full_model.save_pretrained(args.export_path) - - if True: - import os - - import torch.nn.functional as F - from datasets import load_dataset - from tqdm import trange - from transformers import AutoTokenizer - - # Set cache directory to work directory to avoid disk space issues - cache_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), ".hf_cache" - ) - os.makedirs(cache_dir, exist_ok=True) - os.environ["HF_DATASETS_CACHE"] = cache_dir - print(f"Using HuggingFace datasets cache: {cache_dir}") - - def _get_wikitext2(tokenizer: AutoTokenizer, sequence_length: int): - test_dataset_raw = load_dataset( - "wikitext", "wikitext-2-raw-v1", split="test", cache_dir=cache_dir - ) - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - 
num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [ - test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length] - for i in range(num_test_sequences) - ] - return test_loader - - @torch.no_grad() - def _compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange( - 0, num_samples, batch_size, desc="Computing perplexity", leave=False - ): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), - shift_labels.reshape(-1), - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - eval_data = _get_wikitext2(tokenizer, 2048) - ppl = _compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 09a2bf559f..a186b073ae 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -434,6 +434,100 @@ def 
find_quant_cfg_entry_by_path( }, } +NVFP4_STATIC_WO_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq", + "use_sequential": True, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "max", + "use_sequential": False, + }, +} + +NVFP4_STATIC_WO_GPTQ_LITE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} + +NVFP4_DYNAMIC_WO_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "gptq_lite", + "use_sequential": False, + }, +} INT4_AWQ_CFG = { "quant_cfg": [ From 868c7d66eeccea12cd61c0ddd44bb3b325ec273b Mon Sep 17 00:00:00 2001 From: Suguna Velury 
<178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:01:34 +0000 Subject: [PATCH 20/52] initial cleanup Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 53 ------------------ modelopt/torch/export/quant_utils.py | 62 ++++++---------------- modelopt/torch/export/unified_export_hf.py | 11 ++-- modelopt/torch/quantization/__init__.py | 1 - modelopt/torch/quantization/model_calib.py | 4 -- 5 files changed, 20 insertions(+), 111 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 877d1a9d29..aac9f8ccbc 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -584,43 +584,6 @@ def mono_quantize( else: calibrate_loop = create_forward_loop(dataloader=calib_dataloader) - # Phase 1: Collect pre-quantization activations (batch_size=1 to save memory) - if getattr(args, "measure_activation_mse", False): - mse_max_samples = getattr(args, "activation_mse_max_samples", 16) - mse_save_dir = getattr(args, "activation_mse_save_dir", None) - mse_input_path = getattr(args, "activation_mse_input_path", None) - - # Materialize or load a frozen set of MSE inputs so that the exact - # same samples are used across runs and across codebases. 
- if mse_input_path and os.path.isfile(mse_input_path): - mse_data = mtq.ActivationMSELogger.load_data(mse_input_path) - else: - from torch.utils.data import DataLoader as _DataLoader - - mse_dataloader = _DataLoader(calib_dataloader.dataset, batch_size=1, shuffle=False) - if mse_input_path: - mse_data = mtq.ActivationMSELogger.materialize_data( - mse_dataloader, - mse_input_path, - max_samples=mse_max_samples, - ) - else: - # No path given -- materialize in memory only - mse_data = [] - for i, batch in enumerate(mse_dataloader): - if i >= mse_max_samples: - break - t = batch["input_ids"] if isinstance(batch, dict) else batch - mse_data.append(t.cpu()) - - mse_logger = mtq.ActivationMSELogger( - max_samples=mse_max_samples, - layer_filter=getattr(args, "activation_mse_layer_filter", None), - save_dir=mse_save_dir, - ) - print("\n--- Phase 1: Collecting pre-quantization activations ---") - mse_logger.collect(language_model, mse_data, phase="original") - if calibration_only: language_model = mtq.calibrate( language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop @@ -628,16 +591,6 @@ def mono_quantize( else: language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop) - # Phase 2: Compute MSE against stored pre-quant activations - if getattr(args, "measure_activation_mse", False): - print("\n--- Phase 2: Computing per-layer activation MSE ---") - mse_logger.collect(language_model, mse_data, phase="quantized") - mse_logger.compute_mse() - print(mse_logger.summary()) - if mse_save_dir: - mse_logger.save() - del mse_logger, mse_data - # For VL models, update full_model to use the quantized language model if is_nemotron_vl_model: language_model_lineage = get_language_model_from_vl(full_model) @@ -1218,12 +1171,6 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) - parser.add_argument( - "--export_qdq_weights", - help=("Used for GPTQ weights as is without compressed weights for deployment."), - 
default=False, - action="store_true", - ) parser.add_argument( "--verbose", help="Print verbose output (e.g. quantization summary). Disable by --no-verbose.", diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index b762757cb9..674d0596e3 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -46,7 +46,7 @@ ) from modelopt.torch.utils import clear_cuda_cache -from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer +from ..quantization.nn import SequentialQuantizer, TensorQuantizer from .model_config import ( KV_CACHE_FP8, KV_CACHE_INT8, @@ -353,17 +353,15 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return get_scaling_factor(weight_quantizer[0]) quantization_format = get_quantization_format(module) - if quantization_format in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_NVFP4_FP8, ]: - # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers) - if not is_nvfp4_static: - module_name = f"{type(module).__name__}.{weight_name}" - _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) + # Calibrate weight quantizer if amax is not set + module_name = f"{type(module).__name__}.{weight_name}" + _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. 
@@ -373,10 +371,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer( weight_quantizer ) - # Unified method handles both static and dynamic quantizers - return NVFP4QTensor.get_weights_scaling_factor_from_quantizer( - weight_quantizer, + return NVFP4QTensor.get_weights_scaling_factor( weight, + weight_quantizer.block_sizes[-1], weight_scaling_factor_2.to(weight.device), )[0] @@ -410,13 +407,16 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") module_name = f"{type(module).__name__}.{weight_name}" _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) - if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: - # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. - # This is because the kernel dequantizes weight to fp8, which is in range 448. - return weight_quantizer._amax.float() / 448.0 - else: - # Unified method handles both static and dynamic quantizers - return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) + if quantization_format in [ + QUANTIZATION_NVFP4, + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + ]: + return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) + elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: + # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. + # This is because the kernel dequantizes weight to fp8, which is in range 448. 
+ return weight_quantizer._amax.float() / 448.0 # SequentialQuantizer is required if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled: @@ -799,7 +799,7 @@ def process_layer_quant_config(layer_config_dict): layer_config = {"quant_algo": "W8A16"} elif v == "int8_sq": layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"} - elif v in ["nvfp4", "nvfp4_static"]: + elif v == "nvfp4": layer_config = { "quant_algo": "NVFP4", "group_size": block_size_value, @@ -1397,18 +1397,6 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False for module in modules: module.weight_quantizer[-1].amax = weight_amax - # Handle NVFP4StaticQuantizer: unify global_amax for fused layers - elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer): - global_amax_list = [ - m.weight_quantizer.global_amax - for m in modules - if m.weight_quantizer.global_amax is not None - ] - if global_amax_list: - unified_global_amax = torch.max(torch.stack(global_amax_list)) - for module in modules: - module.weight_quantizer.global_amax = unified_global_amax - elif ( modules[0].weight_quantizer.is_enabled and modules[0].weight_quantizer.amax is not None @@ -1493,22 +1481,6 @@ def get_quant_config( if block_size == 0: block_size = get_weight_block_size(module) - # Static NVFP4 uses pre-computed per-block scales from MSE calibration - if quantization_format == QUANTIZATION_NVFP4: - weight_quantizer = getattr(module, "weight_quantizer", None) - if weight_quantizer is None: - # Try to get from first weight attribute - for wn in weight_names: - weight_quantizer = getattr( - module, quantizer_attr_names(wn).weight_quantizer, None - ) - if weight_quantizer is not None: - break - if weight_quantizer is not None: - is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer) - if is_static: - quantization_format = "nvfp4_static" - # Construct per layer config dictionary layer_config_dict[name + ".quantization"] = quantization_format 
layer_config_dict[name + ".awq_block_size"] = block_size diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 8a542b580d..04775715b4 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -52,11 +52,7 @@ from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context -from modelopt.torch.quantization.nn import ( - NVFP4StaticQuantizer, - SequentialQuantizer, - TensorQuantizer, -) +from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names @@ -551,7 +547,6 @@ def _export_quantized_weight( weight, _ = maybe_transpose_expert_weight_dimensions( weight, is_bmm_expert_weight=is_bmm_expert_weight ) - weight_scale = NVFP4QTensor.get_weights_scaling_factor( weight, block_size=block_size, @@ -559,7 +554,7 @@ def _export_quantized_weight( )[0] quantized_weight = to_quantized_weight( - weight.to(torch.bfloat16), + weight.to(dtype), weight_scale, quantization_format, weight_scale_2, @@ -576,7 +571,7 @@ def _export_quantized_weight( ) quantized_weight = to_quantized_weight( - weight.to(torch.bfloat16), + weight.to(dtype), weight_scale, quantization_format, weight_scale_2, diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index 757b844fb1..87dbf30bb5 100644 --- a/modelopt/torch/quantization/__init__.py +++ b/modelopt/torch/quantization/__init__.py @@ -19,7 +19,6 @@ from . 
import mode, plugins, utils # Add methods to mtq namespace -from .activation_mse import ActivationMSELogger, collect_activations, measure_activation_mse from .compress import * from .config import * from .conversion import * diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 24f5a19cd4..d873939543 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -15,9 +15,6 @@ """Calibration utilities.""" -import contextlib -import datetime -import json import math import os import warnings @@ -1971,7 +1968,6 @@ def hessian_forward(self, input, *args, **kwargs): for module in patched_modules: unpatch_forward_method(module, "_forward_no_gptq_hessian") - # Phase 3: Update weights using computed Hessians (same as gptq_lite) print_rank_0("Updating weights using GPTQ algorithm...") for name, module in layer.named_modules(): From 6f8870d70d3567b15cc69fb527113d8651edc8a8 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:24:55 +0000 Subject: [PATCH 21/52] cleanup Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/activation_mse.py | 787 ------------------ modelopt/torch/quantization/config.py | 154 ---- 2 files changed, 941 deletions(-) delete mode 100644 modelopt/torch/quantization/activation_mse.py diff --git a/modelopt/torch/quantization/activation_mse.py b/modelopt/torch/quantization/activation_mse.py deleted file mode 100644 index df90c84a3a..0000000000 --- a/modelopt/torch/quantization/activation_mse.py +++ /dev/null @@ -1,787 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Per-layer activation MSE measurement for quantization analysis. - -This module provides utilities to measure per-linear-layer MSE between a model's -activations before and after quantization. Inspired by FP-Quant's two-phase approach: - -- **Phase 1** (before quantization): ``collect_activations()`` runs the model on - calibration data and stores per-layer outputs in CPU RAM. -- **Phase 2** (after quantization): ``measure_activation_mse()`` runs the quantized - model on the same data and computes MSE on-the-fly against the stored Phase 1 - outputs. Only running scalar accumulators are kept -- no second set of tensors - is stored. 
- -Typical usage in hf_ptq.py:: - - # Phase 1: before quantization - orig_acts = mtq.collect_activations(model, mse_dataloader, max_samples=16) - - # Quantize - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - - # Phase 2: after quantization -- computes MSE incrementally - mse = mtq.measure_activation_mse(model, mse_dataloader, orig_acts, max_samples=16) -""" - -import contextlib -import fnmatch -import hashlib -import os -from collections.abc import Iterable -from datetime import datetime - -import torch -import torch.nn as nn -import torch.nn.functional as F -from tqdm import tqdm - -from modelopt.torch.utils.network import get_decoder_layers - -__all__ = ["ActivationMSELogger", "collect_activations", "measure_activation_mse"] - - -def _tensor_from_output(out) -> torch.Tensor: - """Extract a single tensor from a layer's output (handles tuple returns).""" - if isinstance(out, torch.Tensor): - return out.detach() - return out[0].detach() - - -def _is_linear(module: nn.Module) -> bool: - """Check if a module is a linear layer (covers both nn.Linear and quantized linear).""" - return isinstance(module, nn.Linear) - - -def _matches_filter(name: str, layer_filter: str | None) -> bool: - """Check if a layer name matches the optional filter pattern (fnmatch-style).""" - if layer_filter is None: - return True - return fnmatch.fnmatch(name, layer_filter) - - -def _discover_target_layers( - model: nn.Module, - layer_filter: str | None = None, -) -> dict[str, nn.Module]: - """Discover linear layers within decoder blocks of the model. - - Uses get_decoder_layers() to find transformer blocks, then finds all linear - submodules within those blocks. Falls back to all linear layers in the model - if decoder blocks cannot be identified. - - Args: - model: The model to inspect. - layer_filter: Optional fnmatch pattern to select specific layers - (e.g., ``"*self_attn*"``). - - Returns: - Dict mapping full module path -> module reference. 
- """ - decoder_layers = get_decoder_layers(model) - - targets: dict[str, nn.Module] = {} - - if decoder_layers is not None: - # Build a reverse lookup: module id -> full name in model - module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} - - for block in decoder_layers: - block_name = module_to_name.get(id(block), "") - for sub_name, sub_mod in block.named_modules(): - if _is_linear(sub_mod): - full_name = f"{block_name}.{sub_name}" if block_name else sub_name - if _matches_filter(full_name, layer_filter): - targets[full_name] = sub_mod - else: - # Fallback: scan all modules - for name, module in model.named_modules(): - if _is_linear(module): - if _matches_filter(name, layer_filter): - targets[name] = module - - return targets - - -def _run_batch(model: nn.Module, batch) -> None: - """Run a single batch through the model.""" - if isinstance(batch, dict): - model(**batch) - elif isinstance(batch, (list, tuple)): - model(*batch) - else: - model(batch) - - -@torch.no_grad() -def collect_activations( - model: nn.Module, - dataloader: Iterable, - max_samples: int | None = None, - layer_filter: str | None = None, -) -> dict[str, list[torch.Tensor]]: - """Collect per-linear-layer output activations into CPU memory (Phase 1). - - Registers forward hooks on linear layers within the model's decoder blocks, - runs calibration data through the model, and returns captured per-layer outputs. - - Args: - model: The model to collect activations from (typically pre-quantization). - dataloader: An iterable yielding batches (dicts with ``input_ids``, etc.). - Use batch_size=1 to minimize memory. - max_samples: Maximum number of batches to process. ``None`` means all. - layer_filter: Optional fnmatch pattern to restrict which layers are - collected (e.g., ``"*self_attn*"``). ``None`` means all linear layers - inside decoder blocks. - - Returns: - Dict mapping layer name to a list of output tensors (one per batch, on CPU). 
- """ - was_training = model.training - model.eval() - - # Discover target linear layers - targets = _discover_target_layers(model, layer_filter) - if not targets: - raise ValueError( - f"No linear layers found matching the given filter. layer_filter={layer_filter!r}" - ) - - print(f"Collecting activations for {len(targets)} layers...") - - # Storage: {layer_name: [tensor_per_batch, ...]} - saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} - captured: dict[str, torch.Tensor] = {} - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - # Register hooks - hooks = [] - for name, module in targets.items(): - hooks.append(module.register_forward_hook(_make_hook(name))) - - try: - n_batches = 0 - for batch in tqdm(dataloader, desc="Collecting activations", leave=False): - if max_samples is not None and n_batches >= max_samples: - break - - captured.clear() - _run_batch(model, batch) - - for name in targets: - if name in captured: - saved[name].append(captured[name]) - - n_batches += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - print(f"Collected {n_batches} samples across {len(targets)} layers") - return saved - - -@torch.no_grad() -def measure_activation_mse( - model: nn.Module, - dataloader: Iterable, - orig_activations: dict[str, list[torch.Tensor]], - max_samples: int | None = None, - layer_filter: str | None = None, -) -> dict[str, float]: - """Compute per-layer MSE between stored and live activations (Phase 2). - - Runs the (quantized) model on calibration data and computes MSE on-the-fly - against the pre-quantization activations stored by :func:`collect_activations`. - - Only scalar accumulators (sum of squared errors and element count) are kept - per layer -- no second set of activation tensors is stored. 
- - The MSE for each layer is computed as:: - - MSE = sum_over_all_elements((orig - quant) ^ 2) / total_elements - - Args: - model: The quantized model to measure. - dataloader: Same dataloader used for :func:`collect_activations` - (must yield batches in the same order). - orig_activations: Output of :func:`collect_activations` -- dict mapping - layer name to a list of pre-quantization output tensors. - max_samples: Maximum number of batches to process (should match Phase 1). - layer_filter: Optional fnmatch pattern (should match Phase 1). - - Returns: - Dict mapping layer name to its MSE value. - """ - was_training = model.training - model.eval() - - # Discover target layers on the (now-quantized) model - targets = _discover_target_layers(model, layer_filter) - - # Only measure layers that exist in both the model and orig_activations - common_keys = sorted(set(targets.keys()) & set(orig_activations.keys())) - if not common_keys: - raise ValueError( - "No matching layers between the quantized model and stored activations. " - "Ensure the same layer_filter is used for both phases." 
- ) - - skipped = set(orig_activations.keys()) - set(targets.keys()) - if skipped: - print(f"Warning: {len(skipped)} layers in orig_activations not found in model (skipped)") - - print(f"Computing activation MSE for {len(common_keys)} layers...") - - # Scalar accumulators - sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) - count: dict[str, int] = dict.fromkeys(common_keys, 0) - - captured: dict[str, torch.Tensor] = {} - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - # Register hooks only on common layers - hooks = [targets[name].register_forward_hook(_make_hook(name)) for name in common_keys] - - try: - batch_idx = 0 - for batch in tqdm(dataloader, desc="Computing activation MSE", leave=False): - if max_samples is not None and batch_idx >= max_samples: - break - - captured.clear() - _run_batch(model, batch) - - for name in common_keys: - if name not in captured: - continue - if batch_idx >= len(orig_activations.get(name, [])): - continue - - o = orig_activations[name][batch_idx].float() - q = captured[name].float() - - if o.shape != q.shape: - print( - f"Warning: shape mismatch for {name} batch {batch_idx}: " - f"{o.shape} vs {q.shape}, skipping" - ) - continue - - sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() - count[name] += o.numel() - - batch_idx += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - mse = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in common_keys - } - - return mse - - -# --------------------------------------------------------------------------- -# Portable ActivationMSELogger class -# --------------------------------------------------------------------------- - - -def _portable_discover_target_layers( - model: nn.Module, - layer_filter: str | None = None, -) -> dict[str, nn.Module]: - """Discover linear layers in decoder blocks with a portable fallback 
chain. - - Strategy: - 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). - 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). - 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. - - Within each set of decoder blocks the function collects every ``nn.Linear`` - sub-module and optionally filters by *layer_filter* (fnmatch pattern). - """ - decoder_layers = None - - # 1. Try modelopt helper (may not exist when file is copied elsewhere) - with contextlib.suppress(Exception): - decoder_layers = get_decoder_layers(model) - - # 2. Try common HF / other patterns - if decoder_layers is None: - for attr_chain in ( - ("model", "layers"), - ("decoder", "layers"), - ("transformer", "h"), - ("backbone", "layers"), - ): - obj = model - try: - for attr in attr_chain: - obj = getattr(obj, attr) - if isinstance(obj, nn.ModuleList): - decoder_layers = obj - break - except AttributeError: - continue - - targets: dict[str, nn.Module] = {} - - if decoder_layers is not None: - module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} - for block in decoder_layers: - block_name = module_to_name.get(id(block), "") - for sub_name, sub_mod in block.named_modules(): - if isinstance(sub_mod, nn.Linear): - full_name = f"{block_name}.{sub_name}" if block_name else sub_name - if _matches_filter(full_name, layer_filter): - targets[full_name] = sub_mod - else: - # 3. Fallback: all linear layers - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - if _matches_filter(name, layer_filter): - targets[name] = module - - return targets - - -class ActivationMSELogger: - """Portable activation MSE logger for comparing original vs quantized models. - - Works with both: - - - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` - or ``[B, seq_len]``, consumed via ``model(tensor)``. 
- - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): - ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. - - Guarantees same samples are used for both phases via SHA-256 hashing of - input tensors. Supports saving / loading all activations to disk for - later cross-codebase comparison. - - Example (ModelOpt -- DataLoader with dict batches):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model, dataloader, phase="original") - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - mse_logger.collect(model, dataloader, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - - Example (FP-Quant -- List[Tensor]):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model_orig, calibration_data, phase="original") - mse_logger.collect(model_quant, calibration_data, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - """ - - def __init__( - self, - max_samples: int = 16, - layer_filter: str | None = None, - save_dir: str | None = None, - ): - """Initialize the ActivationMSELogger. - - Args: - max_samples: Maximum number of calibration batches to process per phase. - layer_filter: Optional glob pattern to restrict which layers are tracked. - save_dir: Optional directory path for persisting activation data to disk. 
- """ - self.max_samples = max_samples - self.layer_filter = layer_filter - self.save_dir = save_dir - - # Per-phase state - self.original_activations: dict[str, list[torch.Tensor]] = {} - self.quantized_activations: dict[str, list[torch.Tensor]] = {} - self.input_hashes: list[str] = [] # hashes for "original" phase - self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase - - # Computed after both phases - self.mse_results: dict[str, float] | None = None - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - @torch.no_grad() - def collect( - self, - model: nn.Module, - data: Iterable, - phase: str, - target_modules: dict[str, nn.Module] | None = None, - ) -> None: - """Collect per-linear-layer output activations for a given phase. - - Args: - model: The model to run (original or quantized). - data: An iterable of batches. Each batch can be: - - - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). - - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). - - ``list`` / ``tuple`` of tensors. - phase: ``"original"`` or ``"quantized"``. - target_modules: Optional explicit mapping of ``{name: nn.Module}`` - to attach hooks to. If *None*, layers are auto-discovered - via decoder-block scanning. - """ - if phase not in ("original", "quantized"): - raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") - - was_training = model.training - model.eval() - - # ----- layer discovery ----- - targets = ( - target_modules - if target_modules is not None - else (_portable_discover_target_layers(model, self.layer_filter)) - ) - if not targets: - raise ValueError( - "No linear layers found. Provide target_modules explicitly or " - f"check layer_filter={self.layer_filter!r}." 
- ) - - print( - f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " - f"max_samples={self.max_samples}" - ) - - # ----- storage ----- - saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} - captured: dict[str, torch.Tensor] = {} - hashes: list[str] = [] - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - hooks = [] - for name, module in targets.items(): - hooks.append(module.register_forward_hook(_make_hook(name))) - - try: - n_batches = 0 - for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): - if self.max_samples is not None and n_batches >= self.max_samples: - break - - captured.clear() - self._run_batch(model, batch) - - for name in targets: - if name in captured: - saved[name].append(captured[name]) - - hashes.append(self._hash_batch(batch)) - n_batches += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - # ----- store results on self ----- - if phase == "original": - self.original_activations = saved - self.input_hashes = hashes - else: - self.quantized_activations = saved - self.quant_input_hashes = hashes - # Verify sample consistency - if self.input_hashes: - self._verify_hashes() - - # Invalidate any previous MSE since we have new activations - self.mse_results = None - - print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") - - def compute_mse(self) -> dict[str, float]: - """Compute per-layer MSE between original and quantized activations. - - Returns: - Dict mapping layer name to its MSE value. - - Raises: - ValueError: If either phase has not been collected yet. - """ - if not self.original_activations: - raise ValueError( - "No original activations collected. Call collect(..., phase='original') first." - ) - if not self.quantized_activations: - raise ValueError( - "No quantized activations collected. 
Call collect(..., phase='quantized') first." - ) - - common_keys = sorted( - set(self.original_activations.keys()) & set(self.quantized_activations.keys()) - ) - if not common_keys: - raise ValueError( - "No matching layer names between original and quantized activations. " - "Ensure the same model architecture / layer_filter is used for both phases." - ) - - orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) - quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) - if orig_only: - print( - f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" - ) - if quant_only: - print( - f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" - ) - - sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) - count: dict[str, int] = dict.fromkeys(common_keys, 0) - - for name in common_keys: - orig_list = self.original_activations[name] - quant_list = self.quantized_activations[name] - n = min(len(orig_list), len(quant_list)) - for i in range(n): - o = orig_list[i].float() - q = quant_list[i].float() - if o.shape != q.shape: - print( - f"[ActivationMSELogger] Warning: shape mismatch for {name} " - f"batch {i}: {o.shape} vs {q.shape}, skipping" - ) - continue - sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() - count[name] += o.numel() - - self.mse_results = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") - for key in common_keys - } - return self.mse_results - - def save(self, path: str | None = None) -> str: - """Save all state (activations, hashes, MSE) to disk via ``torch.save``. - - Args: - path: Explicit file path. If *None*, a timestamped file is created - inside ``self.save_dir`` (which must be set). - - Returns: - The path where the file was saved. 
- """ - if path is None: - if self.save_dir is None: - raise ValueError("Provide a path or set save_dir in the constructor.") - os.makedirs(self.save_dir, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") - - payload = { - "max_samples": self.max_samples, - "layer_filter": self.layer_filter, - "input_hashes": self.input_hashes, - "quant_input_hashes": self.quant_input_hashes, - "original_activations": self.original_activations, - "quantized_activations": self.quantized_activations, - "mse": self.mse_results, - } - torch.save(payload, path) - print(f"[ActivationMSELogger] Saved to {path}") - return path - - @classmethod - def load(cls, path: str) -> "ActivationMSELogger": - """Load a previously saved ``ActivationMSELogger`` from disk. - - Args: - path: Path to the ``.pt`` file created by :meth:`save`. - - Returns: - A new ``ActivationMSELogger`` instance with restored state. - """ - payload = torch.load(path, map_location="cpu", weights_only=False) - logger = cls( - max_samples=payload.get("max_samples", 16), - layer_filter=payload.get("layer_filter"), - ) - logger.original_activations = payload.get("original_activations", {}) - logger.quantized_activations = payload.get("quantized_activations", {}) - logger.input_hashes = payload.get("input_hashes", []) - logger.quant_input_hashes = payload.get("quant_input_hashes", []) - logger.mse_results = payload.get("mse") - print(f"[ActivationMSELogger] Loaded from {path}") - return logger - - def summary(self) -> str: - """Return a formatted string summarising per-layer MSE results. - - Computes MSE first if not already done. 
- """ - if self.mse_results is None: - self.compute_mse() - assert self.mse_results is not None - - lines = ["Per-layer activation MSE (original vs quantized):"] - lines.extend( - f" {key}: {self.mse_results[key]:.6e}" for key in sorted(self.mse_results.keys()) - ) - return "\n".join(lines) - - # ------------------------------------------------------------------ - # Pre-materialized MSE data (cross-run / cross-codebase safety) - # ------------------------------------------------------------------ - - @staticmethod - def materialize_data( - data: Iterable, - path: str, - max_samples: int | None = None, - ) -> list[torch.Tensor]: - """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. - - Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a - single ``input_ids`` CPU tensor before saving. The resulting file is a - plain ``List[Tensor]`` that can be loaded in **any** codebase and passed - straight to :meth:`collect`. - - If *path* already exists it is **not** overwritten -- call - :meth:`load_data` instead. - - Args: - data: Iterable of batches (DataLoader, List[Tensor], etc.). - path: Destination ``.pt`` file path. - max_samples: How many batches to keep. ``None`` means all. - - Returns: - The materialised list of CPU tensors (same object that was saved). 
- """ - samples: list[torch.Tensor] = [] - for batch in data: - if max_samples is not None and len(samples) >= max_samples: - break - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - samples.append(t.cpu()) - - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - torch.save(samples, path) - print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") - return samples - - @staticmethod - def load_data(path: str) -> list[torch.Tensor]: - """Load a previously materialised MSE input set. - - Args: - path: Path to the ``.pt`` file created by :meth:`materialize_data`. - - Returns: - ``List[Tensor]`` of input batches (on CPU). - """ - samples = torch.load(path, map_location="cpu", weights_only=True) - print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") - return samples - - # ------------------------------------------------------------------ - # Static / private helpers - # ------------------------------------------------------------------ - - @staticmethod - def _run_batch(model: nn.Module, batch) -> None: - """Run a single batch through the model (handles Tensor, dict, list/tuple). - - Automatically moves inputs to the model's device so that CPU-stored - materialized data works transparently with a CUDA model. 
- """ - device = next(model.parameters()).device - if isinstance(batch, dict): - batch = { - k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() - } - model(**batch) - elif isinstance(batch, torch.Tensor): - model(batch.to(device)) - elif isinstance(batch, (list, tuple)): - batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) - model(*batch) - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - - @staticmethod - def _hash_batch(batch) -> str: - """Compute SHA-256 hash of the primary input tensor in *batch*. - - - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). - - ``Tensor`` -> hashes the tensor directly. - - ``list/tuple`` -> hashes the first element. - """ - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] if batch else None - else: - return "" - - if t is None or not isinstance(t, torch.Tensor): - return "" - return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() - - def _verify_hashes(self) -> None: - """Compare input hashes between original and quantized phases.""" - n = min(len(self.input_hashes), len(self.quant_input_hashes)) - mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) - if mismatches: - print( - f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " - f"different input hashes between original and quantized phases. " - f"The same data may not have been used for both phases!" 
- ) - else: - print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index a186b073ae..c58e2d607a 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -242,53 +242,6 @@ def find_quant_cfg_entry_by_path( {"quantizer_name": "*o_proj*", "enable": False}, # Skip QKV Output Projection ] -SUPER_NVFP4_CONSERVATIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - **super_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, - "algorithm": "max", -} - -SUPER_NVFP4_CONSERVATIVE_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - **super_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - INT8_DEFAULT_CFG = { "quant_cfg": [ @@ -422,113 +375,6 @@ def find_quant_cfg_entry_by_path( "algorithm": "max", } -INT4_BLOCKWISE_WEIGHT_ONLY_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, - 
"algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_WO_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq", - "use_sequential": True, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "max", - "use_sequential": False, - }, -} - -NVFP4_STATIC_WO_GPTQ_LITE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - -NVFP4_DYNAMIC_WO_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "gptq_lite", - "use_sequential": False, - }, -} - INT4_AWQ_CFG = { "quant_cfg": [ *_base_disable_all, From 
91f3d2dd04c50089a66ada67e1020bc09f98f50f Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:31:18 +0000 Subject: [PATCH 22/52] removed stray prints Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d873939543..ec59d113dc 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -358,7 +358,6 @@ def mse_calibrate( if fp8_scale_sweep and is_nvfp4_static: # Replace calibrator with NVFP4MSECalibrator - print("mse_calibrate: Replacing calibrator with NVFP4MSECalibrator") module._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=module._calibrator._axis, @@ -645,7 +644,6 @@ def quant_func(x, amax, quantizer=weight_quantizer): error_func = helper.get_error_func() if fp8_scale_sweep and is_nvfp4_static: - print("local_hessian_calibrate: Replacing calibrator with NVFP4MSECalibrator") weight_quantizer._calibrator = NVFP4MSECalibrator( amax=initial_amax, axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, From cb46d7e3991d131dbfd8ad807af913f6c0e8ca82 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:04:14 +0000 Subject: [PATCH 23/52] fix rebase issues Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 55 ++++++++++++++----- modelopt/torch/export/unified_export_hf.py | 7 ++- modelopt/torch/quantization/mode.py | 2 - .../nn/modules/tensor_quantizer.py | 13 +---- .../torch/quantization/triton/__init__.py | 4 -- 5 files changed, 49 insertions(+), 32 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 674d0596e3..4ceb51cd2c 
100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -46,7 +46,7 @@ ) from modelopt.torch.utils import clear_cuda_cache -from ..quantization.nn import SequentialQuantizer, TensorQuantizer +from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer from .model_config import ( KV_CACHE_FP8, KV_CACHE_INT8, @@ -353,6 +353,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return get_scaling_factor(weight_quantizer[0]) quantization_format = get_quantization_format(module) + if quantization_format in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, @@ -371,9 +372,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer( weight_quantizer ) - return NVFP4QTensor.get_weights_scaling_factor( + # Unified method handles both static and dynamic quantizers + return NVFP4QTensor.get_weights_scaling_factor_from_quantizer( + weight_quantizer, weight, - weight_quantizer.block_sizes[-1], weight_scaling_factor_2.to(weight.device), )[0] @@ -407,16 +409,13 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") module_name = f"{type(module).__name__}.{weight_name}" _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) - if quantization_format in [ - QUANTIZATION_NVFP4, - QUANTIZATION_NVFP4_AWQ, - QUANTIZATION_NVFP4_SVDQUANT, - ]: - return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) - elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: - # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. - # This is because the kernel dequantizes weight to fp8, which is in range 448. 
- return weight_quantizer._amax.float() / 448.0 + if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: + # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. + # This is because the kernel dequantizes weight to fp8, which is in range 448. + return weight_quantizer._amax.float() / 448.0 + else: + # Unified method handles both static and dynamic quantizers + return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) # SequentialQuantizer is required if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled: @@ -799,7 +798,7 @@ def process_layer_quant_config(layer_config_dict): layer_config = {"quant_algo": "W8A16"} elif v == "int8_sq": layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"} - elif v == "nvfp4": + elif v in ["nvfp4", "nvfp4_static"]: layer_config = { "quant_algo": "NVFP4", "group_size": block_size_value, @@ -1397,6 +1396,18 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False for module in modules: module.weight_quantizer[-1].amax = weight_amax + # Handle NVFP4StaticQuantizer: unify global_amax for fused layers + elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer): + global_amax_list = [ + m.weight_quantizer.global_amax + for m in modules + if m.weight_quantizer.global_amax is not None + ] + if global_amax_list: + unified_global_amax = torch.max(torch.stack(global_amax_list)) + for module in modules: + module.weight_quantizer.global_amax = unified_global_amax + elif ( modules[0].weight_quantizer.is_enabled and modules[0].weight_quantizer.amax is not None @@ -1481,6 +1492,22 @@ def get_quant_config( if block_size == 0: block_size = get_weight_block_size(module) + # Static NVFP4 uses pre-computed per-block scales from MSE calibration + if quantization_format == QUANTIZATION_NVFP4: + weight_quantizer = getattr(module, "weight_quantizer", None) + if weight_quantizer is None: + # Try to get from first weight 
attribute + for wn in weight_names: + weight_quantizer = getattr( + module, quantizer_attr_names(wn).weight_quantizer, None + ) + if weight_quantizer is not None: + break + if weight_quantizer is not None: + is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer) + if is_static: + quantization_format = "nvfp4_static" + # Construct per layer config dictionary layer_config_dict[name + ".quantization"] = quantization_format layer_config_dict[name + ".awq_block_size"] = block_size diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 04775715b4..4871d36b08 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -52,7 +52,11 @@ from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context -from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer +from modelopt.torch.quantization.nn import ( + NVFP4StaticQuantizer, + SequentialQuantizer, + TensorQuantizer, +) from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names @@ -547,6 +551,7 @@ def _export_quantized_weight( weight, _ = maybe_transpose_expert_weight_dimensions( weight, is_bmm_expert_weight=is_bmm_expert_weight ) + weight_scale = NVFP4QTensor.get_weights_scaling_factor( weight, block_size=block_size, diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index efc66ffa94..88e93bb770 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -255,8 +255,6 @@ def wrapped_calib_func( else: # Direct calibration (existing behavior) func(model, forward_loop=forward_loop, **kwargs) - else: - raise ValueError(f"No calibration function provided for method: {method}") # Lets get the latest metadata for the quantizer states metadata = {} diff --git 
a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index a62d8620b1..3ff7401ec3 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1346,19 +1346,10 @@ def global_amax(self, value): def _fake_quantize(self, inputs): """Fake quantization using two-level scaling with _amax and _global_amax.""" if self.amax is not None: - # Ensure amax/global_amax are on the same device as inputs. - # After from_pretrained with device_map, quantizer buffers may remain - # on CPU while model weights/activations are on GPU. - amax = self.amax - if amax.device != inputs.device: - amax = amax.to(inputs.device) - global_amax = self.global_amax - if global_amax is not None and global_amax.device != inputs.device: - global_amax = global_amax.to(inputs.device) return static_blockwise_fp4_fake_quant( inputs, - amax, - global_amax, # Can be None, will be computed internally + self.amax, + self.global_amax, # Can be None, will be computed internally True, # quantize_block_scales inputs.dtype, self._pass_through_bwd, diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py index 6e8d4dba11..def70e5914 100644 --- a/modelopt/torch/quantization/triton/__init__.py +++ b/modelopt/torch/quantization/triton/__init__.py @@ -34,10 +34,6 @@ from .fp4_kernel import * from .fp8_kernel import * - # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) - if torch.cuda.get_device_capability() >= (8, 9): - from .fp4_kernel_hopper import * - # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): from .fp4_kernel_hopper import * From 0bde84014463be62a3fa5626178a38ca0b1c7983 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:06:18 +0000 Subject: [PATCH 24/52] minor 
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index c58e2d607a..2721e8d3b9 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -242,7 +242,6 @@ def find_quant_cfg_entry_by_path( {"quantizer_name": "*o_proj*", "enable": False}, # Skip QKV Output Projection ] - INT8_DEFAULT_CFG = { "quant_cfg": [ *_base_disable_all, @@ -375,6 +374,7 @@ def find_quant_cfg_entry_by_path( "algorithm": "max", } + INT4_AWQ_CFG = { "quant_cfg": [ *_base_disable_all, From 6c478a7ce40fc0565b89a50d1db7f9b3e5d9b999 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 22:12:00 +0000 Subject: [PATCH 25/52] tested e2e on qwen Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 15 ++++- modelopt/torch/quantization/config.py | 20 ++++++ modelopt/torch/quantization/mode.py | 4 +- modelopt/torch/quantization/model_calib.py | 73 +++++++++++++++++++--- 4 files changed, 100 insertions(+), 12 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index aac9f8ccbc..835e3b695c 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,6 +24,7 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module +from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -110,6 +111,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, + "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": 
mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, } @@ -951,7 +953,7 @@ def quantize_main( else: # mono quantization - + if args.recipe is not None: print(f"Use recipe {args.recipe} for quantization") recipe = load_recipe(args.recipe) @@ -1029,6 +1031,11 @@ def quantize_main( is_nemotron_vl_model, first_text_speech_dataset, ) + + if args.eval_perplexity and tokenizer is not None: + print("Evaluating Wikitext-2 perplexity...") + evaluate_perplexity(language_model, tokenizer, seq_len=args.calib_seq) + export_quantized( args, full_model, @@ -1187,6 +1194,12 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) + parser.add_argument( + "--eval_perplexity", + help="Evaluate Wikitext-2 perplexity after quantization.", + default=False, + action="store_true", + ) parser.add_argument( "--low_memory_mode", help=( diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 2721e8d3b9..1850771779 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -626,6 +626,25 @@ def _nvfp4_selective_quant_cfg( }, } +NVFP4_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": {"method": "gptq", "use_sequential": True}, +} + MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { "quant_cfg": [ *_base_disable_all, @@ -816,6 +835,7 @@ def _nvfp4_selective_quant_cfg( "NVFP4_AWQ_FULL_CFG", "NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", + "NVFP4_GPTQ_CFG", "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 88e93bb770..df48c72c29 100644 --- a/modelopt/torch/quantization/mode.py +++ 
b/modelopt/torch/quantization/mode.py @@ -242,8 +242,8 @@ def wrapped_calib_func( if sequential: if forward_loop is None: raise ValueError("forward_loop is required for calibration but got None.") - assert method in ["max"], ( - f"Sequential calibration currently only supports max calibration, got {method}" + assert method in ["max", "gptq"], ( + f"Sequential calibration currently only supports max and gptq calibration, got {method}" ) # Wrap with sequential processing sequential_calibrate( diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ec59d113dc..6d4e06b946 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1899,19 +1899,62 @@ def _layer_forward_loop(m, _inputs=layer_inputs): print_rank_0("Sequential calibration completed") +def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. + + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. + + Returns the number of quantizers converted. 
+ """ + converted = 0 + for _name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted + + @torch.no_grad() def gptq( layer: nn.Module, - inputs: list[tuple[tuple, dict]], forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, **kwargs, ): - """GPTQ quantization - a GPTQ variant.""" - # Set weight amax and activation amax'es for the current layer using max_calibrate + """GPTQ quantization - a GPTQ variant. + + Args: + layer: A single decoder layer to quantize. + forward_loop: Callable that replays calibration inputs through the layer. + Provided by ``sequential_calibrate`` which captures per-layer activations. + percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). + block_size: Block size for GPTQ weight update. 
+ """ + import time + + total_start = time.time() + + # Set weight amax and activation amax for the current layer using max_calibrate max_calibrate(layer, forward_loop=forward_loop) + # Promote NVFP4 static quantizers so they use the two-level scaling path + n_promoted = _promote_nvfp4_static_quantizers(layer) + if n_promoted: + print_rank_0(f"Promoted {n_promoted} quantizer(s) to NVFP4StaticQuantizer") + # Dictionary to store hessian matrices for all linear layers in this decoder hessian_state = {} @@ -1955,18 +1998,20 @@ def hessian_forward(self, input, *args, **kwargs): bind_forward_method(module, _make_hessian_forward(name), "_forward_no_gptq_hessian") patched_modules.append(module) - # Run forward passes with the provided inputs to collect Hessians - print_rank_0( - f"Computing Hessians for {len(tensor_mapping)} linear layers using {len(inputs)} batches..." - ) - for args, kwargs_input in inputs: - layer(*args, **kwargs_input) + # Run forward passes to collect Hessians + hessian_start = time.time() + print_rank_0(f"Computing Hessians for {len(tensor_mapping)} linear layers...") + forward_loop(layer) # Unpatch forwards for module in patched_modules: unpatch_forward_method(module, "_forward_no_gptq_hessian") + torch.cuda.synchronize() if torch.cuda.is_available() else None + hessian_time = time.time() - hessian_start + # Phase 3: Update weights using computed Hessians (same as gptq_lite) + weight_update_start = time.time() print_rank_0("Updating weights using GPTQ algorithm...") for name, module in layer.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: @@ -1978,3 +2023,13 @@ def hessian_forward(self, input, *args, **kwargs): # Free memory del hessian_state[module.name] torch.cuda.empty_cache() + + torch.cuda.synchronize() if torch.cuda.is_available() else None + weight_update_time = time.time() - weight_update_start + + total_time = time.time() - total_start + print_rank_0( + f"GPTQ timing - Hessian: {hessian_time:.2f}s, " 
+ f"Weight update: {weight_update_time:.2f}s, " + f"Total: {total_time:.2f}s" + ) From 539d3bfe986784a9b506435d85a5aa88fb2ed01a Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 22:20:44 +0000 Subject: [PATCH 26/52] removed perplexity eval Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 835e3b695c..2b78b3ae4f 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,7 +24,6 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module -from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -1032,10 +1031,6 @@ def quantize_main( first_text_speech_dataset, ) - if args.eval_perplexity and tokenizer is not None: - print("Evaluating Wikitext-2 perplexity...") - evaluate_perplexity(language_model, tokenizer, seq_len=args.calib_seq) - export_quantized( args, full_model, @@ -1194,12 +1189,7 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) - parser.add_argument( - "--eval_perplexity", - help="Evaluate Wikitext-2 perplexity after quantization.", - default=False, - action="store_true", - ) + parser.add_argument( "--low_memory_mode", help=( From f38a5cce490131d4ddce5ad61fe994803a4d4877 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 6 Mar 2026 23:39:59 +0000 Subject: [PATCH 27/52] update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 2b78b3ae4f..cf39122d8a 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,6 +24,7 @@ 
import numpy as np import torch from accelerate.hooks import remove_hook_from_module +from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -696,6 +697,9 @@ def export_quantized( "They will be set at deployment time." ) + if getattr(args, "eval_perplexity", False) and tokenizer is not None: + evaluate_perplexity(full_model, tokenizer, seq_len=2048) + # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization if args.vllm_fakequant_export: From e4f7534c94041f7b8e1bce8b9d84a07653954f95 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:26:47 +0000 Subject: [PATCH 28/52] revert later Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 123 +++++++++++++++++++++++- modelopt/torch/quantization/__init__.py | 8 +- modelopt/torch/quantization/config.py | 16 +++ 3 files changed, 144 insertions(+), 3 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index cf39122d8a..d9e9902027 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,7 +24,6 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module -from eval_perplexity import evaluate_perplexity from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -64,6 +63,11 @@ ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration +from modelopt.torch.quantization.metrics import ( + ActivationMSELogger, + compute_perplexity, + get_wikitext2, +) from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import is_quantized from 
modelopt.torch.utils.dataset_utils import ( @@ -101,6 +105,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "int4_awq": mtq.INT4_AWQ_CFG, "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, "nvfp4": mtq.NVFP4_DEFAULT_CFG, + "nvfp4_wo": mtq.NVFP4_WEIGHT_ONLY_CFG, "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, "nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, @@ -111,6 +116,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, + "nvfp4_wo_gptq": mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG, "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, @@ -698,7 +704,10 @@ def export_quantized( ) if getattr(args, "eval_perplexity", False) and tokenizer is not None: - evaluate_perplexity(full_model, tokenizer, seq_len=2048) + seq_len = getattr(args, "eval_perplexity_seq_len", 2048) + eval_data = get_wikitext2(tokenizer, seq_len) + ppl = compute_perplexity(full_model, eval_data) + print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization @@ -943,6 +952,64 @@ def quantize_main( args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) + # Collect original (unquantized) activations before quantization modifies the model + mse_logger = None + if getattr(args, "measure_activation_mse", False): + n_mse = getattr(args, "activation_mse_max_samples", 16) + mse_save_dir = getattr(args, "activation_mse_save_dir", None) + mse_input_path = getattr(args, "activation_mse_input_path", None) + + # Resolve MSE input data: frozen file (raw text or tokenized) or live dataloader + mse_data = None + if mse_input_path is not None: + if 
mse_input_path.endswith(".json"): + if os.path.isfile(mse_input_path): + print(f"Loading MSE input data from existing .json file: {mse_input_path}") + texts = ActivationMSELogger.load_raw_text(mse_input_path) + mse_data = ActivationMSELogger.tokenize_raw_text( + texts, + tokenizer, + max_length=args.calib_seq, + ) + else: + assert tokenizer is not None, ( + "--activation_mse_input_path with .json requires a tokenizer to decode" + ) + print(f"Creating MSE input data .json file: {mse_input_path}") + texts = ActivationMSELogger.materialize_raw_text( + calib_dataloader, + mse_input_path, + tokenizer=tokenizer, + max_samples=n_mse, + ) + mse_data = ActivationMSELogger.tokenize_raw_text( + texts, + tokenizer, + max_length=args.calib_seq, + ) + elif mse_input_path.endswith(".pt"): + if os.path.isfile(mse_input_path): + print(f"Loading MSE input data from existing .pt file: {mse_input_path}") + mse_data = ActivationMSELogger.load_data(mse_input_path) + else: + print(f"Creating MSE input data .pt file: {mse_input_path}") + mse_data = ActivationMSELogger.materialize_data( + calib_dataloader, + mse_input_path, + max_samples=n_mse, + ) + else: + raise ValueError( + f"--activation_mse_input_path must end with .json or .pt, got: {mse_input_path}" + ) + + if mse_data is None: + mse_data = calib_dataloader + + mse_logger = ActivationMSELogger(max_samples=n_mse, save_dir=mse_save_dir) + print(f"Collecting original (unquantized) activations for MSE over {n_mse} samples...") + mse_logger.collect(language_model, mse_data, phase="original") + if args.auto_quantize_bits: assert len(args.qformat.split(",")) > 1, ( "Auto quantization needs multiple quantization format." 
@@ -1035,6 +1102,22 @@ def quantize_main( first_text_speech_dataset, ) + if mse_logger is not None: + import gc + + print("Collecting quantized activations for MSE...") + mse_logger.collect(language_model, mse_data, phase="quantized") + + mse_logger.compute_mse() + print(mse_logger.summary()) + + if getattr(args, "activation_mse_save_dir", None): + mse_logger.save() + + del mse_logger, mse_data + gc.collect() + torch.cuda.empty_cache() + export_quantized( args, full_model, @@ -1259,6 +1342,42 @@ def parse_args() -> argparse.Namespace: help="Export as vLLM fake-quant checkpoint (produces vllm_fq_modelopt_state.pth " "for use with vllm_serve_fakequant.py).", ) + parser.add_argument( + "--eval_perplexity_seq_len", + type=int, + default=2048, + help="Sequence length for perplexity evaluation (default: 2048).", + ) + parser.add_argument( + "--measure_activation_mse", + action=argparse.BooleanOptionalAction, + default=False, + help="Measure per-layer activation MSE (original vs quantized) after quantization.", + ) + parser.add_argument( + "--activation_mse_max_samples", + type=int, + default=16, + help="Max calibration samples for activation MSE (default: 16).", + ) + parser.add_argument( + "--activation_mse_save_dir", + type=str, + default=None, + help="Directory to save activation MSE results. If not set, results are only printed.", + ) + parser.add_argument( + "--activation_mse_input_path", + type=str, + default=None, + help=( + "Path to frozen MSE input data. Supports two formats:\n" + " .json — raw text (cross-model reuse): if file exists, loads and re-tokenizes " + "with the current model's tokenizer; if not, decodes calibration data to text and saves.\n" + " .pt — tokenized tensors (same-tokenizer reuse): if file exists, loads directly; " + "if not, materializes from calibration data and saves." 
+ ), + ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index 87dbf30bb5..d471e55823 100644 --- a/modelopt/torch/quantization/__init__.py +++ b/modelopt/torch/quantization/__init__.py @@ -16,12 +16,18 @@ """Quantization package.""" # Initialize mode and plugins -from . import mode, plugins, utils +from . import metrics, mode, plugins, utils # Add methods to mtq namespace from .compress import * from .config import * from .conversion import * +from .metrics import ( + ActivationMSELogger, + compute_perplexity, + get_wikitext2, + measure_per_layer_activation_mse, +) from .model_quant import * from .nn.modules.quant_module import QuantModuleRegistry from .utils import update_quant_cfg_with_kv_cache_quant diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 1850771779..08486a2c08 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -626,6 +626,20 @@ def _nvfp4_selective_quant_cfg( }, } +NVFP4_WEIGHT_ONLY_GPTQ_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": {"enable": False}, + **_default_disabled_quantizer_cfg, + }, + "algorithm": {"method": "gptq", "use_sequential": True}, +} + NVFP4_GPTQ_CFG = { "quant_cfg": { "*weight_quantizer": { @@ -836,6 +850,8 @@ def _nvfp4_selective_quant_cfg( "NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", "NVFP4_GPTQ_CFG", + "NVFP4_WEIGHT_ONLY_CFG", + "NVFP4_WEIGHT_ONLY_GPTQ_CFG", "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", From 71772461c6d42e4a1a166d7381e38bd8e2f1e530 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Thu, 19 Mar 2026 06:32:45 +0000 Subject: [PATCH 29/52] 
minor update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/utils/network.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b07ca570c4..b54332375b 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -46,7 +46,6 @@ def _convert_to_wrapped_module_name(name: str) -> str: "ModelLike", "compare_dict", "create_param_grad_clear_hook", - "get_decoder_layers", "get_model_attributes", "get_module_device", "get_same_padding", From ccee65996e58d31aa444e79286e9caa8d214e8cc Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 18 Mar 2026 06:35:56 +0000 Subject: [PATCH 30/52] update Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 142 +++++++++++++++++++-- 1 file changed, 129 insertions(+), 13 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 6d4e06b946..a08cb32181 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1660,6 +1660,103 @@ def prepare_hessian_inverse(h, weight, percdamp): return h_inv +def _build_column_qdq(quantizer, weight_shape): + """Build a fast column-wise quantize-dequantize function for integer quantizers. + + Instead of calling the full TensorQuantizer on the entire weight matrix (which + quantizes all elements) and extracting one column, this returns a closure that + quantizes only a single column using the quantizer's pre-computed amax/scales. + + Since max_calibrate fixes the amax before GPTQ weight updates, quantizing a + single column with the same fixed scale gives bit-identical results to + quantizing the full matrix and extracting that column. + + Args: + quantizer: The weight TensorQuantizer (already calibrated). 
+ weight_shape: Shape of the weight tensor (out_features, in_features). + + Returns: + Tuple of (column_qdq_fn, supported) where: + - column_qdq_fn(column, col_idx) -> qdq_column (if supported) + - supported: True if column-wise qdq is available, False to fall back. + """ + # Unsupported: NVFP4 (two-level FP4 scaling), FP quantization (num_bits is a tuple) + if isinstance(quantizer, NVFP4StaticQuantizer): + return None, False + if isinstance(quantizer._num_bits, tuple): + return None, False + + # Unsupported: pre_quant_scale (SmoothQuant) or rotation transforms mix columns + if getattr(quantizer, "pre_quant_scale", None) is not None: + return None, False + if getattr(quantizer, "rotate_is_enabled", False): + return None, False + + # Need calibrated amax + if not hasattr(quantizer, "_amax") or quantizer._amax is None: + return None, False + + num_bits = quantizer._num_bits + unsigned = getattr(quantizer, "_unsigned", False) + narrow_range = getattr(quantizer, "_narrow_range", False) + max_bound = (2 ** (num_bits - 1 + int(unsigned))) - 1 + min_bound = -max_bound + int(narrow_range) + + amax = quantizer._amax.float() + out_features, in_features = weight_shape + + # Determine quantization geometry from block_sizes + block_sizes = quantizer.block_sizes + group_size = None + if block_sizes is not None: + # Skip dynamic block quantization + if block_sizes.get("type", "static") == "dynamic": + return None, False + group_size = block_sizes.get(-1, None) or block_sizes.get(len(weight_shape) - 1, None) + + if group_size is not None and group_size > 0: + # Per-group block quantization along last dim. + # After _setup_for_blockquant, weight is reshaped to (-1, group_size) with axis=(0,). + # amax shape: (out_features * n_groups, 1) where n_groups = in_features // group_size. 
+ if in_features % group_size != 0: + return None, False # Padding case — fall back + + n_groups = in_features // group_size + + try: + # Reshape amax to (out_features, n_groups) for O(1) group lookup + amax_2d = amax.reshape(out_features, n_groups) + except RuntimeError: + return None, False + + def _column_qdq_group( + col, col_idx, _a=amax_2d, _mx=max_bound, _mn=min_bound, _gs=group_size + ): + col_scale = _mx / _a[:, col_idx // _gs].clamp(min=1e-12) + return torch.clamp(torch.round(col * col_scale), _mn, _mx) / col_scale + + return _column_qdq_group, True + + # Per-channel (axis != None) or per-tensor (axis == None) + axis = quantizer.axis + if axis is not None: + # Per-channel: amax has shape (out_features, 1) or similar + col_scale = max_bound / amax.reshape(-1).clamp(min=1e-12) + + def _column_qdq_channel(col, col_idx, _s=col_scale, _mx=max_bound, _mn=min_bound): + return torch.clamp(torch.round(col * _s), _mn, _mx) / _s + + return _column_qdq_channel, True + + # Per-tensor: single scalar scale + scalar_scale = max_bound / amax.clamp(min=1e-12).item() + + def _column_qdq_tensor(col, col_idx, _s=scalar_scale, _mx=max_bound, _mn=min_bound): + return torch.clamp(torch.round(col * _s), _mn, _mx) / _s + + return _column_qdq_tensor, True + + def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): """Update module weights using GPTQ-style blockwise quantization. 
@@ -1676,22 +1773,41 @@ def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): # Preprocess Hessian: handle dead neurons and add damping h_inv = prepare_hessian_inverse(h, weight, percdamp) + # Try to build fast column-wise qdq (avoids quantizing the full matrix per column) + col_qdq_fn, col_qdq_supported = _build_column_qdq(module.weight_quantizer, weight.shape) + # Process weights in blocks for block_start in range(0, num_cols, block_size): block_end = min(block_start + block_size, num_cols) n_cols = block_end - block_start - wblk = weight.clone() - errs = torch.zeros_like(wblk[:, block_start:block_end]) h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] - for i in range(n_cols): - w_ci = wblk[:, block_start + i] - d = h_inv_cho_blk[i, i] - qdq = module.weight_quantizer(wblk) - weight[:, block_start + i] = qdq[:, block_start + i] - err = (w_ci - qdq[:, block_start + i]) / d - wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) - errs[:, i] = err + if col_qdq_supported: + # Fast path: clone only the block columns, quantize only per-column + wblk = weight[:, block_start:block_end].clone() + errs = torch.zeros_like(wblk) + + for i in range(n_cols): + w_ci = wblk[:, i] + d = h_inv_cho_blk[i, i] + qdq_col = col_qdq_fn(w_ci, block_start + i) + weight[:, block_start + i] = qdq_col + err = (w_ci - qdq_col) / d + wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err + else: + # Fallback: original full-matrix quantization path + wblk = weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + + for i in range(n_cols): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = module.weight_quantizer(wblk) + weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err # Propagate errors to remaining weights weight[:, 
block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) @@ -1895,7 +2011,7 @@ def _layer_forward_loop(m, _inputs=layer_inputs): torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() - + print_rank_0("Sequential calibration completed") @@ -2020,9 +2136,9 @@ def hessian_forward(self, input, *args, **kwargs): blockwise_weight_update( module, hessian, block_size, percdamp, n_samples=state["n_samples"] ) - # Free memory del hessian_state[module.name] - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() torch.cuda.synchronize() if torch.cuda.is_available() else None weight_update_time = time.time() - weight_update_start From 3a5d235afe0a3792d0783263e00d181517ac6800 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 18 Mar 2026 23:30:36 +0000 Subject: [PATCH 31/52] gptq faster Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 108 +++++++--- .../quantization/triton/gptq_fused_kernel.py | 189 ++++++++++++++++++ tests/gpu/torch/quantization/test_gptq.py | 93 ++++++++- 3 files changed, 365 insertions(+), 25 deletions(-) create mode 100644 modelopt/torch/quantization/triton/gptq_fused_kernel.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index a08cb32181..178347b175 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1596,7 +1596,7 @@ def _print_relative_mse_error( delta = q - w mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) suffix = f", n_hessian_samples: {n_samples}" if n_samples is not None else "" - print(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") + print_rank_0(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") def update_hessian(input, hessian, n_samples): @@ -1655,7 +1655,7 @@ def prepare_hessian_inverse(h, 
weight, percdamp): h = torch.cholesky_inverse(torch.linalg.cholesky(h)) h_inv = torch.linalg.cholesky(h, upper=True) except (RuntimeError, torch.linalg.LinAlgError): - print("Warning: Hessian is not positive definite, using identity matrix") + print_rank_0("Warning: Hessian is not positive definite, using identity matrix") h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) return h_inv @@ -1757,37 +1757,104 @@ def _column_qdq_tensor(col, col_idx, _s=scalar_scale, _mx=max_bound, _mn=min_bou return _column_qdq_tensor, True +def _can_use_fused_gptq(quantizer) -> bool: + """Check whether the fused Triton GPTQ kernel can be used for *quantizer*.""" + if not isinstance(quantizer, NVFP4StaticQuantizer): + return False + if not hasattr(quantizer, "_amax") or quantizer._amax is None: + return False + from modelopt.torch.quantization.triton import IS_AVAILABLE as _TRITON_OK + + return _TRITON_OK + + def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): """Update module weights using GPTQ-style blockwise quantization. + Dispatches to one of three internal paths depending on quantizer type: + + 1. **Fused Triton** — for :class:`NVFP4StaticQuantizer` when Triton is + available. Runs the entire column loop in a single GPU kernel per + block (~130x faster than the unfused path on Blackwell GPUs). + 2. **Column-QDQ** — for integer quantizers whose scale geometry allows + single-column fake-quant via :func:`_build_column_qdq`. + 3. **Full-matrix fallback** — calls the quantizer on the full weight matrix + each column (slowest, but always correct). + Args: - module: Neural network module with weight and weight_quantizer - H: Hessian matrix (d x d) - block_size: Size of blocks to process at once - percdamp: Damping percentage for Hessian diagonal - n_samples: Number of Hessian samples for logging (optional) + module: Neural network module with ``weight`` and ``weight_quantizer``. + h: Hessian matrix of shape ``(d, d)``. 
+ block_size: Number of columns processed per block. + percdamp: Damping as a fraction of the mean Hessian diagonal. + n_samples: Number of Hessian samples (used only for logging). """ weight = module.weight.data.float().clone() - _, num_cols = weight.shape + num_rows, num_cols = weight.shape - # Preprocess Hessian: handle dead neurons and add damping h_inv = prepare_hessian_inverse(h, weight, percdamp) - # Try to build fast column-wise qdq (avoids quantizing the full matrix per column) - col_qdq_fn, col_qdq_supported = _build_column_qdq(module.weight_quantizer, weight.shape) + quantizer = module.weight_quantizer + if _can_use_fused_gptq(quantizer): + _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size) + else: + col_qdq_fn, col_qdq_supported = _build_column_qdq(quantizer, weight.shape) + _blockwise_weight_update_unfused( + weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported + ) + + _print_relative_mse_error(weight, module.weight.float(), h, module.name, n_samples) + module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) + + +def _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size): + """Fused Triton path for NVFP4: one kernel launch per block.""" + from modelopt.torch.quantization.triton.gptq_fused_kernel import gptq_fused_block + + group_size = quantizer.block_sizes.get(-1, None) or quantizer.block_sizes.get(1, None) + num_groups = math.ceil(num_cols / group_size) + amax_grouped = quantizer._amax.float().reshape(num_rows, num_groups).contiguous() + global_amax = quantizer.global_amax.float() - # Process weights in blocks for block_start in range(0, num_cols, block_size): block_end = min(block_start + block_size, num_cols) - n_cols = block_end - block_start + n_cols_blk = block_end - block_start + + w_block = weight[:, block_start:block_end].clone().contiguous() + h_inv_cho_blk = h_inv[block_start:block_end, 
block_start:block_end].contiguous() + + qw_block, err_block = gptq_fused_block( + w_block, + amax_grouped, + global_amax, + h_inv_cho_blk, + group_size, + block_start, + n_cols_blk, + ) + + weight[:, block_start:block_end] = qw_block + if block_end < num_cols: + weight[:, block_end:].addmm_( + err_block[:, :n_cols_blk], + h_inv[block_start:block_end, block_end:], + alpha=-1, + ) + + +def _blockwise_weight_update_unfused( + weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported +): + """Column-QDQ or full-matrix fallback for non-NVFP4 quantizers.""" + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + n_cols_blk = block_end - block_start h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] if col_qdq_supported: - # Fast path: clone only the block columns, quantize only per-column wblk = weight[:, block_start:block_end].clone() errs = torch.zeros_like(wblk) - for i in range(n_cols): + for i in range(n_cols_blk): w_ci = wblk[:, i] d = h_inv_cho_blk[i, i] qdq_col = col_qdq_fn(w_ci, block_start + i) @@ -1796,27 +1863,20 @@ def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) errs[:, i] = err else: - # Fallback: original full-matrix quantization path wblk = weight.clone() errs = torch.zeros_like(wblk[:, block_start:block_end]) - for i in range(n_cols): + for i in range(n_cols_blk): w_ci = wblk[:, block_start + i] d = h_inv_cho_blk[i, i] - qdq = module.weight_quantizer(wblk) + qdq = quantizer(wblk) weight[:, block_start + i] = qdq[:, block_start + i] err = (w_ci - qdq[:, block_start + i]) / d wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) errs[:, i] = err - # Propagate errors to remaining weights weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) - # Print relative mse error - _print_relative_mse_error(weight, module.weight.float(), 
h, module.name, n_samples) - # Update module weights - module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) - def gptq_lite( model: nn.Module, diff --git a/modelopt/torch/quantization/triton/gptq_fused_kernel.py b/modelopt/torch/quantization/triton/gptq_fused_kernel.py new file mode 100644 index 0000000000..21d84713a1 --- /dev/null +++ b/modelopt/torch/quantization/triton/gptq_fused_kernel.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused Triton kernel for the GPTQ blockwise weight-update inner loop. + +The standard GPTQ inner loop launches ~10-15 CUDA kernels per column +(amax lookup, FP4 quantization, error computation, rank-1 update). +For ``block_size=128`` that is ~1 500 kernel launches per block, each with +~5-10 us of launch overhead dominating actual compute. + +This module fuses the entire inner loop into a **single** Triton kernel per +block. Rows are independent and map to Triton programs; columns are processed +sequentially inside each program so the rank-1 error update is carried forward +without synchronisation. + +Supported quantisation format: **NVFP4 static block quantisation** (two-level +scaling with per-group amax and a global amax). 
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+__all__ = ["gptq_fused_block"]
+
+# -- NVFP4 constants used by the kernel ------------------------------------
+# Maximum representable FP4-E2M1 value (mantissa 1.5 * 2^2 = 6.0, decoded via
+# the standard E2M1 table: {0, 0.5, 1, 1.5, 2, 3, 4, 6}).
+_FP4_MAX = 6.0
+# FP8-E4M3 has max representable value 448.
+_FP8_E4M3_MAX = 448.0
+
+
+@triton.jit
+def _gptq_fused_block_kernel(
+    w_ptr,  # [num_rows, BLOCK_SIZE] working weight block (in-place)
+    qw_ptr,  # [num_rows, BLOCK_SIZE] output: quantized weights
+    err_ptr,  # [num_rows, BLOCK_SIZE] output: quantization errors
+    amax_ptr,  # [num_rows, num_groups] per-group amax, row-major
+    global_amax_ptr,  # scalar float32 on device
+    hinv_ptr,  # [BLOCK_SIZE, BLOCK_SIZE] upper Cholesky of H^{-1}
+    num_rows,
+    num_groups,
+    group_size: tl.constexpr,
+    block_start,  # column offset of this block in the full weight matrix
+    n_cols,  # actual columns in this block (may be < BLOCK_SIZE)
+    BLOCK_SIZE: tl.constexpr,
+):
+    """One program per row; sequentially quantizes columns, propagating errors."""
+    row = tl.program_id(0)
+    if row >= num_rows:
+        return
+
+    # Base pointers for this row
+    w_base = w_ptr + row * BLOCK_SIZE
+    qw_base = qw_ptr + row * BLOCK_SIZE
+    err_base = err_ptr + row * BLOCK_SIZE
+    amax_row_base = amax_ptr + row * num_groups
+
+    # Pre-compute global FP8 scale factors (constant across columns)
+    global_amax = tl.load(global_amax_ptr).to(tl.float32)
+    global_scale = global_amax / 6.0  # _FP4_MAX
+    fp8_inv_scale = tl.where(global_scale > 0.0, 1.0 / (448.0 / global_scale), 0.0)
+
+    j_range = tl.arange(0, BLOCK_SIZE)
+
+    for i in range(BLOCK_SIZE):
+        wi = tl.load(w_base + i)
+
+        # -- Compute NVFP4 two-level scale for this column's group -----------
+        col_idx = block_start + i
+        group_idx = col_idx // group_size
+        raw_amax = tl.load(amax_row_base + group_idx).to(tl.float32)
+        raw_scale = raw_amax / 6.0  # _FP4_MAX
+
+        # FP8-quantize the block scale: scale * fp8_scale -> cast E4M3 -> back
+        fp8_scale = tl.where(global_scale > 0.0, 448.0 / global_scale, 1.0)
+        si = (raw_scale * fp8_scale).to(tl.float8e4nv).to(tl.float32) * fp8_inv_scale
+
+        # Guard: replace zero / nan / inf scale with 1.0
+        # NOTE: ``si != si`` is the standard NaN check in Triton (no math.isnan).
+        si_safe = tl.where(
+            (si == 0.0) | (si != si) | (tl.abs(si) == float("inf")),  # noqa: PLR0124
+            1.0,
+            si,
+        )
+
+        # -- FP4-E2M1 fake quantization (nearest-round to 8 levels) ----------
+        abs_scaled = tl.abs(wi) / si_safe
+        q_val = tl.where(
+            abs_scaled <= 0.25,
+            0.0,
+            tl.where(
+                abs_scaled < 0.75,
+                0.5,
+                tl.where(
+                    abs_scaled <= 1.25,
+                    1.0,
+                    tl.where(
+                        abs_scaled < 1.75,
+                        1.5,
+                        tl.where(
+                            abs_scaled <= 2.5,
+                            2.0,
+                            tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)),
+                        ),
+                    ),
+                ),
+            ),
+        )
+
+        qi = q_val * si_safe * tl.where(wi >= 0.0, 1.0, -1.0)
+        tl.store(qw_base + i, qi)
+
+        # -- GPTQ error and rank-1 update ------------------------------------
+        di = tl.load(hinv_ptr + i * BLOCK_SIZE + i)
+        err_i = (wi - qi) / di
+        tl.store(err_base + i, err_i)
+
+        j_mask = (j_range > i) & (j_range < n_cols)
+        hinv_row = tl.load(hinv_ptr + i * BLOCK_SIZE + j_range, mask=j_mask, other=0.0)
+        w_rem = tl.load(w_base + j_range, mask=j_mask, other=0.0)
+        w_rem = w_rem - err_i * hinv_row
+        tl.store(w_base + j_range, w_rem, mask=j_mask)
+
+
+def gptq_fused_block(
+    w_block: torch.Tensor,
+    amax_grouped: torch.Tensor,
+    global_amax: torch.Tensor,
+    h_inv_cho_blk: torch.Tensor,
+    group_size: int,
+    block_start: int,
+    n_cols: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Run the GPTQ column loop for one block in a single Triton kernel launch.
+
+    Args:
+        w_block: Working weight block of shape ``[num_rows, block_size]`` (will be cloned).
+        amax_grouped: Per-group amax of shape ``[num_rows, num_groups]``.
+        global_amax: Scalar tensor with the global amax.
+        h_inv_cho_blk: Upper Cholesky factor of H^{-1}, shape ``[block_size, block_size]``.
+ group_size: NVFP4 quantization group size (typically 16). + block_start: Column offset of this block in the full weight matrix. + n_cols: Actual number of columns in this block (``<= block_size``). + + Returns: + Tuple of ``(qw_block, err_block)`` each of shape ``[num_rows, block_size]``. + """ + num_rows, block_size = w_block.shape + num_groups = amax_grouped.shape[1] + + w_block = w_block.contiguous() + amax_grouped = amax_grouped.contiguous() + h_inv_cho_blk = h_inv_cho_blk.contiguous() + + qw_block = torch.empty_like(w_block) + err_block = torch.empty_like(w_block) + + grid = (num_rows,) + with torch.cuda.device(w_block.device): + _gptq_fused_block_kernel[grid]( + w_block, + qw_block, + err_block, + amax_grouped, + global_amax, + h_inv_cho_blk, + num_rows, + num_groups, + group_size, + block_start, + n_cols, + BLOCK_SIZE=block_size, + ) + + return qw_block, err_block diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index ec95e1e8d3..20bdb8f51d 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -21,7 +21,14 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight -from modelopt.torch.quantization.model_calib import blockwise_weight_update, update_hessian +from modelopt.torch.quantization.model_calib import ( + _blockwise_weight_update_fused, + _blockwise_weight_update_unfused, + blockwise_weight_update, + prepare_hessian_inverse, + update_hessian, +) +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader @@ -293,3 +300,87 @@ def test_gptq_e2e_flow(quant_cfg): print( f"Generated ids after quantization: {tokenizer.decode(generated_ids_after_ptq[0], skip_special_tokens=True)}" ) + + +@pytest.mark.parametrize("dim", [256, 
512]) +def test_fused_vs_unfused_nvfp4(dim): + """Verify that the fused Triton GPTQ kernel produces equivalent results to the unfused path. + + The fused kernel computes NVFP4 quantisation inline using Triton intrinsics, + which can differ slightly from the PyTorch-level quantiser path (different FP + rounding order). On real models (dim >= 4096) the relative MSE difference is + typically < 0.1%; at the smaller dims used here the tolerance is set to 20%. + """ + from modelopt.torch.quantization.model_calib import _promote_nvfp4_static_quantizers + + torch.manual_seed(RAND_SEED) + block_size = min(128, dim) + + # NVFP4_WEIGHT_ONLY_GPTQ_CFG uses *static* blocks, which get promoted to + # NVFP4StaticQuantizer — the prerequisite for the fused Triton path. + quant_cfg = copy.deepcopy(mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG) + quant_cfg["algorithm"] = "max" # calibrate only, don't run GPTQ + + model = torch.nn.Linear(dim, dim, bias=False).to("cuda") + model.name = "test_fused" + original_weight = model.weight.data.clone() + inp = torch.randn(4, 32, dim, device="cuda") + + mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(inp)) + + # Promote to NVFP4StaticQuantizer (normally done by gptq / sequential_calibrate) + n_promoted = _promote_nvfp4_static_quantizers(model) + assert n_promoted > 0, "Expected at least one quantizer to be promoted" + + quantizer = model.weight_quantizer + assert isinstance(quantizer, NVFP4StaticQuantizer), ( + f"Expected NVFP4StaticQuantizer, got {type(quantizer).__name__}" + ) + + # Restore original weight and compute Hessian + model.weight.data = original_weight.clone() + hessian = torch.zeros(dim, dim, dtype=torch.float32) + n_samples = 0 + hessian, n_samples = update_hessian(inp, hessian, n_samples) + hessian = hessian.to("cuda") + + # --- Run fused path --- + weight_fused = original_weight.float().clone() + num_rows, num_cols = weight_fused.shape + h_inv = prepare_hessian_inverse(hessian, weight_fused, percdamp=0.01) + 
_blockwise_weight_update_fused(weight_fused, h_inv, quantizer, num_rows, num_cols, block_size) + + # --- Run unfused path --- + weight_unfused = original_weight.float().clone() + h_inv_unfused = prepare_hessian_inverse(hessian, weight_unfused, percdamp=0.01) + _blockwise_weight_update_unfused( + weight_unfused, h_inv_unfused, quantizer, num_cols, block_size, None, False + ) + + # Both paths must produce non-trivial updates + assert not torch.equal(weight_fused, original_weight.float()), ( + "Fused path did not update weights" + ) + assert not torch.equal(weight_unfused, original_weight.float()), ( + "Unfused path did not update weights" + ) + + # Compare Hessian-weighted relative MSE + def _relative_mse(q, w, h): + delta = q - w + return (delta.mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6)).item() + + orig_f = original_weight.float() + mse_fused = _relative_mse(weight_fused, orig_f, hessian) + mse_unfused = _relative_mse(weight_unfused, orig_f, hessian) + + assert mse_fused > 0, "Fused MSE should be positive" + assert mse_unfused > 0, "Unfused MSE should be positive" + + # At small test dimensions, inline Triton FP4 rounding can diverge up to ~15% + # from the PyTorch path. On production-scale layers this drops below 0.1%. 
+ relative_mse_diff = abs(mse_fused - mse_unfused) / max(mse_fused, mse_unfused) + assert relative_mse_diff < 0.20, ( + f"Fused ({mse_fused:.6e}) and unfused ({mse_unfused:.6e}) MSE differ by " + f"{relative_mse_diff:.2%}, expected < 20%" + ) From d4c8a11b33ddde0162335bb02461fd25faceccc6 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Fri, 20 Mar 2026 23:08:13 +0000 Subject: [PATCH 32/52] added metrics files, remove later Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 3 + .../torch/quantization/metrics/__init__.py | 28 + .../quantization/metrics/activation_mse.py | 831 ++++++++++++++++++ .../torch/quantization/metrics/perplexity.py | 81 ++ modelopt/torch/quantization/model_calib.py | 6 +- 5 files changed, 948 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/quantization/metrics/__init__.py create mode 100644 modelopt/torch/quantization/metrics/activation_mse.py create mode 100644 modelopt/torch/quantization/metrics/perplexity.py diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 2a851de5c6..4af91820ea 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -567,6 +567,9 @@ def get_model( try: hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + if not hasattr(hf_config, "moe_latent_size"): + hf_config.moe_latent_size = None + if is_nemotron_vl(hf_config): print( "Detected Nemotron VL model from config. " diff --git a/modelopt/torch/quantization/metrics/__init__.py b/modelopt/torch/quantization/metrics/__init__.py new file mode 100644 index 0000000000..a1c737c3c0 --- /dev/null +++ b/modelopt/torch/quantization/metrics/__init__.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +"""Metrics for evaluating quantized models.""" + +from .activation_mse import ActivationMSELogger, measure_per_layer_activation_mse +from .perplexity import compute_perplexity, get_wikitext2 + +__all__ = [ + "ActivationMSELogger", + "compute_perplexity", + "get_wikitext2", + "measure_per_layer_activation_mse", +] diff --git a/modelopt/torch/quantization/metrics/activation_mse.py b/modelopt/torch/quantization/metrics/activation_mse.py new file mode 100644 index 0000000000..1b60977ee1 --- /dev/null +++ b/modelopt/torch/quantization/metrics/activation_mse.py @@ -0,0 +1,831 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# mypy: ignore-errors +# ruff: noqa: D107, D205, PERF401, PLR0124 + +"""Per-layer activation MSE between original (unquantized) and quantized model. + +Includes the portable ``ActivationMSELogger`` class that works across codebases +(FP-Quant List[Tensor] style *and* ModelOpt DataLoader-of-dicts style). + +Ported from FP-Quant: https://github.com/IST-DASLab/FP-Quant +""" + +import fnmatch +import gc +import hashlib +import json +import os +from datetime import datetime + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + + +def _get_module(block: nn.Module, name: str) -> nn.Module: + """Get submodule from block by dotted name, e.g. 'self_attn.q_proj'.""" + obj = block + for part in name.split("."): + obj = getattr(obj, part) + return obj + + +def _get_linear_layer_names(block: nn.Module) -> list[str]: + """Collect relative names of linear layers in a transformer block (same as GPTQ).""" + names = [] + for name, layer in block.named_modules(): + if isinstance(layer, nn.Linear): + names.append(name) + return names + + +def _tensor_from_output(out) -> torch.Tensor: + """Extract a single tensor from layer output (handle tuple return).""" + if isinstance(out, torch.Tensor): + return out.detach() + return out[0].detach() + + +def _discover_layer_keys(blocks, layer_names, num_blocks): + """Build list of valid layer keys.""" + keys = [] + for i in range(num_blocks): + for name in layer_names: + try: + _get_module(blocks[i], name) + except AttributeError: + continue + keys.append(f"model.layers.{i}.{name}") + return keys + + +def _collect_outputs( + model: nn.Module, + blocks: nn.ModuleList, + layer_names: list[str], + layer_keys: list[str], + calibration_data: list[torch.Tensor], + device: torch.device | str, + num_blocks: int, + desc: str, +) -> dict[str, list[torch.Tensor]]: + """Run model on calibration data, capture per-layer outputs (moved to CPU).""" + captured: dict[str, torch.Tensor] = {} + saved: dict[str, 
list[torch.Tensor]] = {k: [] for k in layer_keys} + + def make_hook(key: str): + def hook(_module: nn.Module, _input: tuple, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + hooks = [] + for i in range(num_blocks): + for name in layer_names: + key = f"model.layers.{i}.{name}" + if key not in saved: + continue + try: + mod = _get_module(blocks[i], name) + except AttributeError: + continue + hooks.append(mod.register_forward_hook(make_hook(key))) + + try: + for sample in tqdm(calibration_data, desc=desc, leave=False): + inp = sample.unsqueeze(0) if sample.dim() == 1 else sample + inp = inp.to(device) + captured.clear() + with torch.no_grad(): + _ = model(inp) + for key in layer_keys: + if key in captured: + saved[key].append(captured[key]) + finally: + for h in hooks: + h.remove() + return saved + + +@torch.no_grad() +def measure_per_layer_activation_mse( + model_orig: nn.Module, + model_quant: nn.Module, + calibration_data: list[torch.Tensor], + device: torch.device | str, + log_wandb: bool = False, + max_samples: int | None = None, +) -> dict[str, float]: + """Measure per-linear-layer MSE between activations of the original (unquantized) + model and the quantized model on the same calibration data. + + Runs each model on GPU one at a time to avoid OOM. + Returns a dict mapping layer key (e.g. "model.layers.0.self_attn.q_proj") to MSE. 
+ """ + if max_samples is not None and max_samples > 0: + calibration_data = calibration_data[:max_samples] + + blocks_quant = model_quant.model.layers + blocks_orig = model_orig.model.layers + num_blocks = len(blocks_quant) + assert len(blocks_orig) == num_blocks + + layer_names = _get_linear_layer_names(blocks_quant[0]) + layer_keys = _discover_layer_keys(blocks_quant, layer_names, num_blocks) + + # --- Phase 1: run quantized model on GPU, save outputs to CPU --- + print(" Phase 1/2: collecting quantized model outputs...") + model_quant.to(device) + quant_outputs = _collect_outputs( + model_quant, + blocks_quant, + layer_names, + layer_keys, + calibration_data, + device, + num_blocks, + desc="Activation MSE (quant)", + ) + # Free GPU for original model + model_quant.cpu() + gc.collect() + torch.cuda.empty_cache() + + # --- Phase 2: run original model on GPU, compute MSE vs stored quant --- + print(" Phase 2/2: collecting original model outputs and computing MSE...") + model_orig.to(device) + + # Instead of storing orig outputs, compute MSE on the fly per sample + sum_sq: dict[str, float] = dict.fromkeys(layer_keys, 0.0) + count: dict[str, int] = dict.fromkeys(layer_keys, 0) + + captured: dict[str, torch.Tensor] = {} + + def make_hook(key: str): + def hook(_module: nn.Module, _input: tuple, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + hooks = [] + for i in range(num_blocks): + for name in layer_names: + key = f"model.layers.{i}.{name}" + if key not in sum_sq: + continue + try: + mod = _get_module(blocks_orig[i], name) + except AttributeError: + continue + hooks.append(mod.register_forward_hook(make_hook(key))) + + try: + for sample_idx, sample in enumerate( + tqdm(calibration_data, desc="Activation MSE (orig)", leave=False) + ): + inp = sample.unsqueeze(0) if sample.dim() == 1 else sample + inp = inp.to(device) + captured.clear() + _ = model_orig(inp) + for key in layer_keys: + if key not in captured: + continue + if 
sample_idx >= len(quant_outputs.get(key, [])): + continue + o = captured[key].float() + q = quant_outputs[key][sample_idx].float() + if o.shape != q.shape: + continue + sum_sq[key] += F.mse_loss(o, q, reduction="sum").item() + count[key] += o.numel() + finally: + for h in hooks: + h.remove() + + # Free original model from GPU + model_orig.cpu() + gc.collect() + torch.cuda.empty_cache() + + # Move quantized model back to GPU for downstream usage + model_quant.to(device) + + mse = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in layer_keys + } + + if log_wandb: + try: + import wandb + + for key, val in mse.items(): + if val == val: # skip nan + wandb.log({f"activation_mse/{key}": val}) + except ImportError: + pass + + return mse + + +# --------------------------------------------------------------------------- +# Portable ActivationMSELogger class +# --------------------------------------------------------------------------- + + +def _matches_filter(name: str, layer_filter: str | None) -> bool: + """Check if a layer name matches the optional filter pattern (fnmatch-style).""" + if layer_filter is None: + return True + return fnmatch.fnmatch(name, layer_filter) + + +def _portable_discover_target_layers( + model: nn.Module, + layer_filter: str | None = None, +) -> dict[str, nn.Module]: + """Discover linear layers in decoder blocks with a portable fallback chain. + + Strategy: + 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). + 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). + 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. + + Within each set of decoder blocks the function collects every ``nn.Linear`` + sub-module and optionally filters by *layer_filter* (fnmatch pattern). + """ + decoder_layers = None + + # 1. 
Try modelopt helper + try: + from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector + + decoder_layers = LayerActivationCollector.get_decoder_layers(model) + except Exception: + pass + + # 2. Try common HF / other patterns + if decoder_layers is None: + for attr_chain in ( + ("model", "layers"), + ("decoder", "layers"), + ("transformer", "h"), + ("backbone", "layers"), + ): + obj = model + try: + for attr in attr_chain: + obj = getattr(obj, attr) + if isinstance(obj, nn.ModuleList): + decoder_layers = obj + break + except AttributeError: + continue + + targets: dict[str, nn.Module] = {} + + if decoder_layers is not None: + module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} + for block in decoder_layers: + block_name = module_to_name.get(id(block), "") + for sub_name, sub_mod in block.named_modules(): + if isinstance(sub_mod, nn.Linear): + full_name = f"{block_name}.{sub_name}" if block_name else sub_name + if _matches_filter(full_name, layer_filter): + targets[full_name] = sub_mod + else: + # 3. Fallback: all linear layers + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + if _matches_filter(name, layer_filter): + targets[name] = module + + return targets + + +class ActivationMSELogger: + """Portable activation MSE logger for comparing original vs quantized models. + + Works with both: + + - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` + or ``[B, seq_len]``, consumed via ``model(tensor)``. + - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): + ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. + + Guarantees same samples are used for both phases via SHA-256 hashing of + input tensors. Supports saving / loading all activations to disk for + later cross-codebase comparison. 
+ + Example (FP-Quant -- List[Tensor]):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model_orig, calibration_data, phase="original") + mse_logger.collect(model_quant, calibration_data, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + + Example (ModelOpt -- DataLoader with dict batches):: + + mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") + mse_logger.collect(model, dataloader, phase="original") + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + mse_logger.collect(model, dataloader, phase="quantized") + results = mse_logger.compute_mse() + print(mse_logger.summary()) + mse_logger.save() + """ + + def __init__( + self, + max_samples: int = 16, + layer_filter: str | None = None, + save_dir: str | None = None, + ): + self.max_samples = max_samples + self.layer_filter = layer_filter + self.save_dir = save_dir + + # Per-phase state + self.original_activations: dict[str, list[torch.Tensor]] = {} + self.quantized_activations: dict[str, list[torch.Tensor]] = {} + self.input_hashes: list[str] = [] # hashes for "original" phase + self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase + + # Computed after both phases + self.mse_results: dict[str, float] | None = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @torch.no_grad() + def collect( + self, + model: nn.Module, + data, + phase: str, + target_modules: dict[str, nn.Module] | None = None, + ) -> None: + """Collect per-linear-layer output activations for a given phase. + + Args: + model: The model to run (original or quantized). + data: An iterable of batches. Each batch can be: + + - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). + - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). 
+ - ``list`` / ``tuple`` of tensors. + phase: ``"original"`` or ``"quantized"``. + target_modules: Optional explicit mapping of ``{name: nn.Module}`` + to attach hooks to. If *None*, layers are auto-discovered + via decoder-block scanning. + """ + if phase not in ("original", "quantized"): + raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") + + was_training = model.training + model.eval() + + # ----- layer discovery ----- + targets = ( + target_modules + if target_modules is not None + else (_portable_discover_target_layers(model, self.layer_filter)) + ) + if not targets: + raise ValueError( + "No linear layers found. Provide target_modules explicitly or " + f"check layer_filter={self.layer_filter!r}." + ) + + print( + f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " + f"max_samples={self.max_samples}" + ) + + # ----- storage ----- + saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} + captured: dict[str, torch.Tensor] = {} + hashes: list[str] = [] + + def _make_hook(key: str): + def hook(_module: nn.Module, _input, output) -> None: + captured[key] = _tensor_from_output(output).cpu() + + return hook + + hooks = [] + for name, module in targets.items(): + hooks.append(module.register_forward_hook(_make_hook(name))) + + try: + n_batches = 0 + for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): + if self.max_samples is not None and n_batches >= self.max_samples: + break + + captured.clear() + self._run_batch(model, batch) + + for name in targets: + if name in captured: + saved[name].append(captured[name]) + + hashes.append(self._hash_batch(batch)) + n_batches += 1 + finally: + for h in hooks: + h.remove() + + model.train(was_training) + + # ----- store results on self ----- + if phase == "original": + self.original_activations = saved + self.input_hashes = hashes + else: + self.quantized_activations = saved + self.quant_input_hashes = hashes + # Verify sample consistency + if 
self.input_hashes: + self._verify_hashes() + + # Invalidate any previous MSE since we have new activations + self.mse_results = None + + print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") + + def compute_mse(self) -> dict[str, float]: + """Compute per-layer MSE between original and quantized activations. + + Returns: + Dict mapping layer name to its MSE value. + + Raises: + ValueError: If either phase has not been collected yet. + """ + if not self.original_activations: + raise ValueError( + "No original activations collected. Call collect(..., phase='original') first." + ) + if not self.quantized_activations: + raise ValueError( + "No quantized activations collected. Call collect(..., phase='quantized') first." + ) + + common_keys = sorted( + set(self.original_activations.keys()) & set(self.quantized_activations.keys()) + ) + if not common_keys: + raise ValueError( + "No matching layer names between original and quantized activations. " + "Ensure the same model architecture / layer_filter is used for both phases." 
+ ) + + orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) + quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) + if orig_only: + print( + f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" + ) + if quant_only: + print( + f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" + ) + + sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) + count: dict[str, int] = dict.fromkeys(common_keys, 0) + + for name in common_keys: + orig_list = self.original_activations[name] + quant_list = self.quantized_activations[name] + n = min(len(orig_list), len(quant_list)) + for i in range(n): + o = orig_list[i].float() + q = quant_list[i].float() + if o.shape != q.shape: + print( + f"[ActivationMSELogger] Warning: shape mismatch for {name} " + f"batch {i}: {o.shape} vs {q.shape}, skipping" + ) + continue + sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() + count[name] += o.numel() + + self.mse_results = { + key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") + for key in common_keys + } + return self.mse_results + + def save(self, path: str | None = None) -> str: + """Save all state (activations, hashes, MSE) to disk via ``torch.save``. + + Args: + path: Explicit file path. If *None*, a timestamped file is created + inside ``self.save_dir`` (which must be set). + + Returns: + The path where the file was saved. 
+ """ + if path is None: + if self.save_dir is None: + raise ValueError("Provide a path or set save_dir in the constructor.") + os.makedirs(self.save_dir, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") + + payload = { + "max_samples": self.max_samples, + "layer_filter": self.layer_filter, + "input_hashes": self.input_hashes, + "quant_input_hashes": self.quant_input_hashes, + "original_activations": self.original_activations, + "quantized_activations": self.quantized_activations, + "mse": self.mse_results, + } + torch.save(payload, path) + print(f"[ActivationMSELogger] Saved to {path}") + return path + + @classmethod + def load(cls, path: str) -> "ActivationMSELogger": + """Load a previously saved ``ActivationMSELogger`` from disk. + + Args: + path: Path to the ``.pt`` file created by :meth:`save`. + + Returns: + A new ``ActivationMSELogger`` instance with restored state. + """ + payload = torch.load(path, map_location="cpu", weights_only=False) + logger = cls( + max_samples=payload.get("max_samples", 16), + layer_filter=payload.get("layer_filter"), + ) + logger.original_activations = payload.get("original_activations", {}) + logger.quantized_activations = payload.get("quantized_activations", {}) + logger.input_hashes = payload.get("input_hashes", []) + logger.quant_input_hashes = payload.get("quant_input_hashes", []) + logger.mse_results = payload.get("mse") + print(f"[ActivationMSELogger] Loaded from {path}") + return logger + + def summary(self) -> str: + """Return a formatted string summarising per-layer MSE results. + + Computes MSE first if not already done. 
+ """ + if self.mse_results is None: + self.compute_mse() + assert self.mse_results is not None + + lines = ["Per-layer activation MSE (original vs quantized):"] + for key in sorted(self.mse_results.keys()): + lines.append(f" {key}: {self.mse_results[key]:.6e}") + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Pre-materialized MSE data (cross-run / cross-codebase safety) + # ------------------------------------------------------------------ + + @staticmethod + def materialize_data( + data, + path: str, + max_samples: int | None = None, + ) -> list[torch.Tensor]: + """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. + + Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a + single ``input_ids`` CPU tensor before saving. The resulting file is a + plain ``List[Tensor]`` that can be loaded in **any** codebase and passed + straight to :meth:`collect`. + + If *path* already exists it is **not** overwritten -- call + :meth:`load_data` instead. + + Args: + data: Iterable of batches (DataLoader, List[Tensor], etc.). + path: Destination ``.pt`` file path. + max_samples: How many batches to keep. ``None`` means all. + + Returns: + The materialised list of CPU tensors (same object that was saved). 
+ """ + samples: list[torch.Tensor] = [] + for batch in data: + if max_samples is not None and len(samples) >= max_samples: + break + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + samples.append(t.cpu()) + + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + torch.save(samples, path) + print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") + return samples + + @staticmethod + def load_data(path: str) -> list[torch.Tensor]: + """Load a previously materialised MSE input set. + + Args: + path: Path to the ``.pt`` file created by :meth:`materialize_data`. + + Returns: + ``List[Tensor]`` of input batches (on CPU). + """ + samples = torch.load(path, map_location="cpu", weights_only=True) + print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") + return samples + + # ------------------------------------------------------------------ + # Raw-text materialization (cross-model / cross-tokenizer reuse) + # ------------------------------------------------------------------ + + @staticmethod + def materialize_raw_text( + data, + path: str, + tokenizer=None, + max_samples: int | None = None, + ) -> list[str]: + """Save raw text strings to a JSON file for cross-model reuse. + + Extracts text from batches by decoding ``input_ids`` with the provided + *tokenizer*. The saved JSON file can be loaded by any model regardless + of its vocabulary and re-tokenized via :meth:`tokenize_raw_text`. + + Args: + data: Iterable of batches (DataLoader, ``List[Tensor]``, etc.). + path: Destination ``.json`` file path. + tokenizer: A HuggingFace tokenizer with a ``decode`` method. + Required to convert token IDs back to text. + max_samples: How many batches to keep. ``None`` means all. 
+ + Returns: + The list of decoded text strings (same content that was saved). + """ + if tokenizer is None: + raise ValueError( + "tokenizer is required for materialize_raw_text to decode input_ids back to text." + ) + + texts: list[str] = [] + for batch in data: + if max_samples is not None and len(texts) >= max_samples: + break + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + + if t.dim() == 1: + t = t.unsqueeze(0) + for row in t: + if max_samples is not None and len(texts) >= max_samples: + break + texts.append(tokenizer.decode(row, skip_special_tokens=True)) + + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + payload = {"texts": texts, "max_samples": len(texts)} + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + + print(f"[ActivationMSELogger] Saved {len(texts)} raw text samples -> {path}") + return texts + + @staticmethod + def load_raw_text(path: str) -> list[str]: + """Load raw text strings from a JSON file created by :meth:`materialize_raw_text`. + + Args: + path: Path to the ``.json`` file. + + Returns: + List of raw text strings. + """ + with open(path, encoding="utf-8") as f: + payload = json.load(f) + texts = payload["texts"] + print(f"[ActivationMSELogger] Loaded {len(texts)} raw text samples from {path}") + return texts + + @staticmethod + def tokenize_raw_text( + texts: list[str], + tokenizer, + max_length: int = 2048, + ) -> list[torch.Tensor]: + """Tokenize raw text strings into a ``List[Tensor]`` for :meth:`collect`. + + Each string is independently tokenized and truncated to *max_length*. + Returns one ``[1, seq_len]`` tensor per string — the same format + expected by :meth:`collect` and :func:`compute_perplexity`. 
+ + Args: + texts: List of raw text strings (from :meth:`load_raw_text`). + tokenizer: A HuggingFace tokenizer. + max_length: Maximum token length per sample (default: 2048). + + Returns: + ``List[Tensor]`` of tokenized inputs on CPU. + """ + samples: list[torch.Tensor] = [] + for text in texts: + encoded = tokenizer( + text, + return_tensors="pt", + max_length=max_length, + truncation=True, + add_special_tokens=False, + ) + samples.append(encoded.input_ids.cpu()) + print(f"[ActivationMSELogger] Tokenized {len(samples)} samples (max_length={max_length})") + return samples + + # ------------------------------------------------------------------ + # Static / private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _run_batch(model: nn.Module, batch) -> None: + """Run a single batch through the model (handles Tensor, dict, list/tuple). + + Automatically moves inputs to the model's device so that CPU-stored + materialized data works transparently with a CUDA model. + """ + device = next(model.parameters()).device + if isinstance(batch, dict): + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() + } + model(**batch) + elif isinstance(batch, torch.Tensor): + model(batch.to(device)) + elif isinstance(batch, (list, tuple)): + batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) + model(*batch) + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + + @staticmethod + def _hash_batch(batch) -> str: + """Compute SHA-256 hash of the primary input tensor in *batch*. + + - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). + - ``Tensor`` -> hashes the tensor directly. + - ``list/tuple`` -> hashes the first element. 
+ """ + if isinstance(batch, dict): + t = batch.get("input_ids", next(iter(batch.values()))) + elif isinstance(batch, torch.Tensor): + t = batch + elif isinstance(batch, (list, tuple)): + t = batch[0] if batch else None + else: + return "" + + if t is None or not isinstance(t, torch.Tensor): + return "" + return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() + + def _verify_hashes(self) -> None: + """Compare input hashes between original and quantized phases.""" + n = min(len(self.input_hashes), len(self.quant_input_hashes)) + mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) + if mismatches: + print( + f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " + f"different input hashes between original and quantized phases. " + f"The same data may not have been used for both phases!" + ) + else: + print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") diff --git a/modelopt/torch/quantization/metrics/perplexity.py b/modelopt/torch/quantization/metrics/perplexity.py new file mode 100644 index 0000000000..2b592914ae --- /dev/null +++ b/modelopt/torch/quantization/metrics/perplexity.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# mypy: ignore-errors +# ruff: noqa: D103, PERF401 + +"""Perplexity evaluation for language models. + +Ported from FP-Quant: https://github.com/IST-DASLab/FP-Quant +""" + +import torch +import torch.nn.functional as F +from tqdm import trange + + +@torch.no_grad() +def compute_perplexity(model, data, batch_size: int = 1): + num_samples = len(data) + device = next(model.parameters()).device + # Running estimate of negative log-likelihood + nll_running = 0 + # Number of tokens processed so far + tokens_processed = 0 + # Loop through each batch + for i in trange(0, num_samples, batch_size, desc="Computing perplexity", leave=False): + j = min(i + batch_size, num_samples) + inputs = torch.cat(data[i:j]).to(device) + # Forward pass through the model + lm_logits = model(inputs).logits + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + # Compute loss + loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1) + ) + # Calculate negative log likelihood + a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) + b = tokens_processed / (tokens_processed + shift_labels.numel()) + nll_running = a * loss + b * nll_running + # Update number of processed tokens + tokens_processed += shift_labels.numel() + # Compute perplexity + ppl = nll_running.exp().item() + return ppl + + +def get_wikitext2(tokenizer, sequence_length: int): + """Load WikiText-2 test set as a list of tokenized sequences for perplexity evaluation. + + Args: + tokenizer: HuggingFace tokenizer. + sequence_length: Length of each evaluation sequence. + + Returns: + List of tensors, each of shape ``[1, sequence_length]``. 
+ """ + from datasets import load_dataset + + test_dataset_raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + test_dataset_tok = tokenizer( + "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" + ).input_ids + num_test_sequences = test_dataset_tok.numel() // sequence_length + test_loader = [] + for i in range(num_test_sequences): + test_loader.append(test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length]) + return test_loader diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 178347b175..343ca01910 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -2157,8 +2157,12 @@ def gptq( def _make_hessian_forward(module_name): def hessian_forward(self, input, *args, **kwargs): inp = input.to_local() if hasattr(input, "to_local") else input + if self.input_quantizer is not None and self.input_quantizer.is_enabled: + hessian_input = self.input_quantizer(inp) + else: + hessian_input = inp state = hessian_state[module_name] - hessian, n_samples = update_hessian(inp, state["hessian"], state["n_samples"]) + hessian, n_samples = update_hessian(hessian_input, state["hessian"], state["n_samples"]) hessian_state[module_name] = {"hessian": hessian, "n_samples": n_samples} self.weight_quantizer.disable() From 2129511a69ab350f62bf22d13dd0e4b4594c34a9 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:12:21 +0000 Subject: [PATCH 33/52] claude review Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 3 -- modelopt/torch/quantization/config.py | 21 ++++---------- modelopt/torch/quantization/model_calib.py | 21 ++++++++++++-- tests/gpu/torch/quantization/test_gptq.py | 33 +++++++++++++--------- 4 files changed, 44 insertions(+), 34 deletions(-) diff --git a/examples/llm_ptq/example_utils.py 
b/examples/llm_ptq/example_utils.py index 4af91820ea..2a851de5c6 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -567,9 +567,6 @@ def get_model( try: hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) - if not hasattr(hf_config, "moe_latent_size"): - hf_config.moe_latent_size = None - if is_nemotron_vl(hf_config): print( "Detected Nemotron VL model from config. " diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 08486a2c08..37e0628465 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1409,19 +1409,15 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig): ) class GPTQConfig(QuantizeAlgorithmConfig): - """The config for GPTQ lite. - - GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. + """The config for GPTQ quantization. - GPTQ lite does not perform sequential quantization of layers. This means that the updated - activations are not used to process the next layer. + GPTQ minimizes the layer-wise quantization error by using second-order (Hessian) information + to perform blockwise weight updates that compensate for rounding loss. Layers are quantized + sequentially so that each layer's Hessian is computed from activations that already reflect + the quantization of preceding layers. The default values are taken from the official GPTQ implementation: https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 - - Note: This feature is currently experimental and may not translate to improved accuracy as expected. 
- - """ method: Literal["gptq"] = ModeloptField("gptq") @@ -1438,12 +1434,7 @@ class GPTQConfig(QuantizeAlgorithmConfig): description="""The block size for GPTQ weight update, which must be a multiple of the group_size used in the quantization.""", ) - hessian_state_path: str | None = ModeloptField( - default=None, - title="Path to the Hessian state file.", - description="""The path to the Hessian state file. If hessian path exists, we load from - hessian file instead of recomputing them.""", - ) + class SmoothQuantCalibConfig(QuantizeAlgorithmConfig): diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 343ca01910..8b40bd53d5 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -17,6 +17,7 @@ import math import os +import time import warnings from collections.abc import Callable from functools import partial @@ -1850,6 +1851,8 @@ def _blockwise_weight_update_unfused( n_cols_blk = block_end - block_start h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] + # wblk is a scratch copy for intra-block error propagation; weight gets + # the final quantized values. Inter-block errors are propagated via addmm_ below. if col_qdq_supported: wblk = weight[:, block_start:block_end].clone() errs = torch.zeros_like(wblk) @@ -2110,7 +2113,21 @@ def gptq( block_size: int = 128, **kwargs, ): - """GPTQ quantization - a GPTQ variant. + """GPTQ quantization for a single decoder layer. + + Invoked by ``sequential_calibrate`` which walks layers one at a time so each + layer sees activations already updated by the quantization of preceding layers. + Within a layer the steps are: + + 1. ``max_calibrate`` to set amax values from the current activations. + 2. Promote eligible quantizers to ``NVFP4StaticQuantizer`` (two-level scaling). + 3. Collect per-linear-layer Hessian matrices via forward hooks. + 4. 
Blockwise weight updates using the inverse Hessian to compensate for + rounding error (the core GPTQ column-wise update). + + In contrast to ``gptq_lite``, which quantizes all layers in parallel using the + original (unquantized) activations, this method performs sequential calibration + and therefore produces more accurate Hessian estimates. Args: layer: A single decoder layer to quantize. @@ -2119,8 +2136,6 @@ def gptq( percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. """ - import time - total_start = time.time() # Set weight amax and activation amax for the current layer using max_calibrate diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 20bdb8f51d..1203c20ef7 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -55,8 +55,11 @@ def test_update_hessian(): f"Expected hessian shape ({features}, {features}), got {updated_hessian.shape}" ) - # Verify sample count is updated correctly (incremented by batch_size) - assert new_n_samples == batch_size, f"Expected n_samples={batch_size}, got {new_n_samples}" + # Verify sample count is updated correctly (incremented by total tokens = batch * seq_len) + expected_n_samples = batch_size * seq_len + assert new_n_samples == expected_n_samples, ( + f"Expected n_samples={expected_n_samples}, got {new_n_samples}" + ) # Verify hessian is not all zeros after update assert not torch.allclose(updated_hessian, torch.zeros_like(updated_hessian)), ( @@ -79,22 +82,23 @@ def test_update_hessian(): # Manual calculation: # input_flat shape: (features, batch*seq) = (2, 12), all ones - # scaled_input = sqrt(2/6) * input_flat = sqrt(1/3) * ones(2, 12) - # outer_product = scaled_input @ scaled_input.t() = (2/6) * ones(2,12) @ ones(12,2) = [[4,4], [4,4]] - # Note: The scaling factor is (2/n_samples), so with n_samples=6 and 12 tokens: (2/6)*12 = 4 - expected_hessian = 
torch.ones(features, features, dtype=torch.float32) * 4.0 + # n_samples = batch * seq = 12 (token count after flattening) + # scaled_input = sqrt(2/12) * ones(2, 12) + # outer_product = (2/12) * ones(2,12) @ ones(12,2) = [[2,2], [2,2]] + expected_n_samples = batch_size * seq_len # 12 tokens + expected_hessian = torch.ones(features, features, dtype=torch.float32) * 2.0 assert torch.allclose(updated_hessian, expected_hessian, atol=1e-5), ( f"Expected hessian {expected_hessian}, got {updated_hessian}" ) - assert new_n_samples == batch_size + assert new_n_samples == expected_n_samples # Test 3: Accumulated hessians - verify equivalence # Processing [6,2,2] in one step should equal processing [2,2,2] three times seq_len = 2 features = 2 - # Process in 3 steps of batch_size=2 + # Process in 3 steps of batch_size=2 (4 tokens each, 12 total) hessian_accumulated = torch.zeros(features, features, dtype=torch.float32) n_samples_accumulated = 0 @@ -111,7 +115,8 @@ def test_update_hessian(): assert torch.allclose(hessian_accumulated, expected_hessian, atol=1e-5), ( f"Accumulated hessian should match expected: expected {expected_hessian}, got {hessian_accumulated}" ) - assert n_samples_accumulated == 6, f"Expected n_samples=6, got {n_samples_accumulated}" + # 3 batches * 2 batch_size * 2 seq_len = 12 tokens + assert n_samples_accumulated == 12, f"Expected n_samples=12, got {n_samples_accumulated}" @pytest.mark.parametrize( @@ -146,14 +151,16 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): hessian, n_samples = update_hessian(input, hessian, n_samples) - # Verify n_samples is update using hessian matrix - assert n_samples == input.shape[0], "n_samples should be equal to input.shape[0]" + # Verify n_samples counts total tokens (batch * seq_len) after flattening + expected_tokens = input.shape[0] * input.shape[1] # 2 * 16 = 32 + assert n_samples == expected_tokens, f"n_samples should be {expected_tokens}, got {n_samples}" # Perform another forward pass 
to update hessian matrix input_2 = torch.randn(3, 16, dim).to("cuda") hessian, n_samples = update_hessian(input_2, hessian, n_samples) - assert n_samples == input.shape[0] + input_2.shape[0], ( - "n_samples should be equal to input.shape[0] + input_2.shape[0]" + expected_tokens_2 = expected_tokens + input_2.shape[0] * input_2.shape[1] # 32 + 48 = 80 + assert n_samples == expected_tokens_2, ( + f"n_samples should be {expected_tokens_2}, got {n_samples}" ) hessian = hessian.to(input.device) From 364cc54d72e229c9d6be0e650c5a091bc9f550c7 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:40:02 +0000 Subject: [PATCH 34/52] remove stray files Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 126 +-- modelopt/torch/quantization/__init__.py | 8 +- .../torch/quantization/metrics/__init__.py | 28 - .../quantization/metrics/activation_mse.py | 831 ------------------ .../torch/quantization/metrics/perplexity.py | 81 -- 5 files changed, 2 insertions(+), 1072 deletions(-) delete mode 100644 modelopt/torch/quantization/metrics/__init__.py delete mode 100644 modelopt/torch/quantization/metrics/activation_mse.py delete mode 100644 modelopt/torch/quantization/metrics/perplexity.py diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d9e9902027..f9ba9784ba 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -15,7 +15,6 @@ import argparse import copy -import os import random import time import warnings @@ -63,11 +62,6 @@ ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration -from modelopt.torch.quantization.metrics import ( - ActivationMSELogger, - compute_perplexity, - get_wikitext2, -) from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from 
modelopt.torch.quantization.utils import is_quantized from modelopt.torch.utils.dataset_utils import ( @@ -105,7 +99,6 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "int4_awq": mtq.INT4_AWQ_CFG, "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, "nvfp4": mtq.NVFP4_DEFAULT_CFG, - "nvfp4_wo": mtq.NVFP4_WEIGHT_ONLY_CFG, "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, "nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, @@ -116,7 +109,6 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, - "nvfp4_wo_gptq": mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG, "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, @@ -703,12 +695,6 @@ def export_quantized( "They will be set at deployment time." ) - if getattr(args, "eval_perplexity", False) and tokenizer is not None: - seq_len = getattr(args, "eval_perplexity_seq_len", 2048) - eval_data = get_wikitext2(tokenizer, seq_len) - ppl = compute_perplexity(full_model, eval_data) - print(f"Wikitext-2 perplexity: {round(ppl, 2):.2f}") - # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization if args.vllm_fakequant_export: @@ -952,64 +938,6 @@ def quantize_main( args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) - # Collect original (unquantized) activations before quantization modifies the model - mse_logger = None - if getattr(args, "measure_activation_mse", False): - n_mse = getattr(args, "activation_mse_max_samples", 16) - mse_save_dir = getattr(args, "activation_mse_save_dir", None) - mse_input_path = getattr(args, "activation_mse_input_path", None) - - # Resolve MSE input data: frozen file (raw text or tokenized) or live 
dataloader - mse_data = None - if mse_input_path is not None: - if mse_input_path.endswith(".json"): - if os.path.isfile(mse_input_path): - print(f"Loading MSE input data from existing .json file: {mse_input_path}") - texts = ActivationMSELogger.load_raw_text(mse_input_path) - mse_data = ActivationMSELogger.tokenize_raw_text( - texts, - tokenizer, - max_length=args.calib_seq, - ) - else: - assert tokenizer is not None, ( - "--activation_mse_input_path with .json requires a tokenizer to decode" - ) - print(f"Creating MSE input data .json file: {mse_input_path}") - texts = ActivationMSELogger.materialize_raw_text( - calib_dataloader, - mse_input_path, - tokenizer=tokenizer, - max_samples=n_mse, - ) - mse_data = ActivationMSELogger.tokenize_raw_text( - texts, - tokenizer, - max_length=args.calib_seq, - ) - elif mse_input_path.endswith(".pt"): - if os.path.isfile(mse_input_path): - print(f"Loading MSE input data from existing .pt file: {mse_input_path}") - mse_data = ActivationMSELogger.load_data(mse_input_path) - else: - print(f"Creating MSE input data .pt file: {mse_input_path}") - mse_data = ActivationMSELogger.materialize_data( - calib_dataloader, - mse_input_path, - max_samples=n_mse, - ) - else: - raise ValueError( - f"--activation_mse_input_path must end with .json or .pt, got: {mse_input_path}" - ) - - if mse_data is None: - mse_data = calib_dataloader - - mse_logger = ActivationMSELogger(max_samples=n_mse, save_dir=mse_save_dir) - print(f"Collecting original (unquantized) activations for MSE over {n_mse} samples...") - mse_logger.collect(language_model, mse_data, phase="original") - if args.auto_quantize_bits: assert len(args.qformat.split(",")) > 1, ( "Auto quantization needs multiple quantization format." 
@@ -1102,22 +1030,6 @@ def quantize_main( first_text_speech_dataset, ) - if mse_logger is not None: - import gc - - print("Collecting quantized activations for MSE...") - mse_logger.collect(language_model, mse_data, phase="quantized") - - mse_logger.compute_mse() - print(mse_logger.summary()) - - if getattr(args, "activation_mse_save_dir", None): - mse_logger.save() - - del mse_logger, mse_data - gc.collect() - torch.cuda.empty_cache() - export_quantized( args, full_model, @@ -1342,43 +1254,7 @@ def parse_args() -> argparse.Namespace: help="Export as vLLM fake-quant checkpoint (produces vllm_fq_modelopt_state.pth " "for use with vllm_serve_fakequant.py).", ) - parser.add_argument( - "--eval_perplexity_seq_len", - type=int, - default=2048, - help="Sequence length for perplexity evaluation (default: 2048).", - ) - parser.add_argument( - "--measure_activation_mse", - action=argparse.BooleanOptionalAction, - default=False, - help="Measure per-layer activation MSE (original vs quantized) after quantization.", - ) - parser.add_argument( - "--activation_mse_max_samples", - type=int, - default=16, - help="Max calibration samples for activation MSE (default: 16).", - ) - parser.add_argument( - "--activation_mse_save_dir", - type=str, - default=None, - help="Directory to save activation MSE results. If not set, results are only printed.", - ) - parser.add_argument( - "--activation_mse_input_path", - type=str, - default=None, - help=( - "Path to frozen MSE input data. Supports two formats:\n" - " .json — raw text (cross-model reuse): if file exists, loads and re-tokenizes " - "with the current model's tokenizer; if not, decodes calibration data to text and saves.\n" - " .pt — tokenized tensors (same-tokenizer reuse): if file exists, loads directly; " - "if not, materializes from calibration data and saves." 
- ), - ) - + args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index d471e55823..87dbf30bb5 100644 --- a/modelopt/torch/quantization/__init__.py +++ b/modelopt/torch/quantization/__init__.py @@ -16,18 +16,12 @@ """Quantization package.""" # Initialize mode and plugins -from . import metrics, mode, plugins, utils +from . import mode, plugins, utils # Add methods to mtq namespace from .compress import * from .config import * from .conversion import * -from .metrics import ( - ActivationMSELogger, - compute_perplexity, - get_wikitext2, - measure_per_layer_activation_mse, -) from .model_quant import * from .nn.modules.quant_module import QuantModuleRegistry from .utils import update_quant_cfg_with_kv_cache_quant diff --git a/modelopt/torch/quantization/metrics/__init__.py b/modelopt/torch/quantization/metrics/__init__.py deleted file mode 100644 index a1c737c3c0..0000000000 --- a/modelopt/torch/quantization/metrics/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# mypy: ignore-errors - -"""Metrics for evaluating quantized models.""" - -from .activation_mse import ActivationMSELogger, measure_per_layer_activation_mse -from .perplexity import compute_perplexity, get_wikitext2 - -__all__ = [ - "ActivationMSELogger", - "compute_perplexity", - "get_wikitext2", - "measure_per_layer_activation_mse", -] diff --git a/modelopt/torch/quantization/metrics/activation_mse.py b/modelopt/torch/quantization/metrics/activation_mse.py deleted file mode 100644 index 1b60977ee1..0000000000 --- a/modelopt/torch/quantization/metrics/activation_mse.py +++ /dev/null @@ -1,831 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# mypy: ignore-errors -# ruff: noqa: D107, D205, PERF401, PLR0124 - -"""Per-layer activation MSE between original (unquantized) and quantized model. - -Includes the portable ``ActivationMSELogger`` class that works across codebases -(FP-Quant List[Tensor] style *and* ModelOpt DataLoader-of-dicts style). - -Ported from FP-Quant: https://github.com/IST-DASLab/FP-Quant -""" - -import fnmatch -import gc -import hashlib -import json -import os -from datetime import datetime - -import torch -import torch.nn as nn -import torch.nn.functional as F -from tqdm import tqdm - - -def _get_module(block: nn.Module, name: str) -> nn.Module: - """Get submodule from block by dotted name, e.g. 
'self_attn.q_proj'.""" - obj = block - for part in name.split("."): - obj = getattr(obj, part) - return obj - - -def _get_linear_layer_names(block: nn.Module) -> list[str]: - """Collect relative names of linear layers in a transformer block (same as GPTQ).""" - names = [] - for name, layer in block.named_modules(): - if isinstance(layer, nn.Linear): - names.append(name) - return names - - -def _tensor_from_output(out) -> torch.Tensor: - """Extract a single tensor from layer output (handle tuple return).""" - if isinstance(out, torch.Tensor): - return out.detach() - return out[0].detach() - - -def _discover_layer_keys(blocks, layer_names, num_blocks): - """Build list of valid layer keys.""" - keys = [] - for i in range(num_blocks): - for name in layer_names: - try: - _get_module(blocks[i], name) - except AttributeError: - continue - keys.append(f"model.layers.{i}.{name}") - return keys - - -def _collect_outputs( - model: nn.Module, - blocks: nn.ModuleList, - layer_names: list[str], - layer_keys: list[str], - calibration_data: list[torch.Tensor], - device: torch.device | str, - num_blocks: int, - desc: str, -) -> dict[str, list[torch.Tensor]]: - """Run model on calibration data, capture per-layer outputs (moved to CPU).""" - captured: dict[str, torch.Tensor] = {} - saved: dict[str, list[torch.Tensor]] = {k: [] for k in layer_keys} - - def make_hook(key: str): - def hook(_module: nn.Module, _input: tuple, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - hooks = [] - for i in range(num_blocks): - for name in layer_names: - key = f"model.layers.{i}.{name}" - if key not in saved: - continue - try: - mod = _get_module(blocks[i], name) - except AttributeError: - continue - hooks.append(mod.register_forward_hook(make_hook(key))) - - try: - for sample in tqdm(calibration_data, desc=desc, leave=False): - inp = sample.unsqueeze(0) if sample.dim() == 1 else sample - inp = inp.to(device) - captured.clear() - with torch.no_grad(): - _ = 
model(inp) - for key in layer_keys: - if key in captured: - saved[key].append(captured[key]) - finally: - for h in hooks: - h.remove() - return saved - - -@torch.no_grad() -def measure_per_layer_activation_mse( - model_orig: nn.Module, - model_quant: nn.Module, - calibration_data: list[torch.Tensor], - device: torch.device | str, - log_wandb: bool = False, - max_samples: int | None = None, -) -> dict[str, float]: - """Measure per-linear-layer MSE between activations of the original (unquantized) - model and the quantized model on the same calibration data. - - Runs each model on GPU one at a time to avoid OOM. - Returns a dict mapping layer key (e.g. "model.layers.0.self_attn.q_proj") to MSE. - """ - if max_samples is not None and max_samples > 0: - calibration_data = calibration_data[:max_samples] - - blocks_quant = model_quant.model.layers - blocks_orig = model_orig.model.layers - num_blocks = len(blocks_quant) - assert len(blocks_orig) == num_blocks - - layer_names = _get_linear_layer_names(blocks_quant[0]) - layer_keys = _discover_layer_keys(blocks_quant, layer_names, num_blocks) - - # --- Phase 1: run quantized model on GPU, save outputs to CPU --- - print(" Phase 1/2: collecting quantized model outputs...") - model_quant.to(device) - quant_outputs = _collect_outputs( - model_quant, - blocks_quant, - layer_names, - layer_keys, - calibration_data, - device, - num_blocks, - desc="Activation MSE (quant)", - ) - # Free GPU for original model - model_quant.cpu() - gc.collect() - torch.cuda.empty_cache() - - # --- Phase 2: run original model on GPU, compute MSE vs stored quant --- - print(" Phase 2/2: collecting original model outputs and computing MSE...") - model_orig.to(device) - - # Instead of storing orig outputs, compute MSE on the fly per sample - sum_sq: dict[str, float] = dict.fromkeys(layer_keys, 0.0) - count: dict[str, int] = dict.fromkeys(layer_keys, 0) - - captured: dict[str, torch.Tensor] = {} - - def make_hook(key: str): - def hook(_module: nn.Module, 
_input: tuple, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - hooks = [] - for i in range(num_blocks): - for name in layer_names: - key = f"model.layers.{i}.{name}" - if key not in sum_sq: - continue - try: - mod = _get_module(blocks_orig[i], name) - except AttributeError: - continue - hooks.append(mod.register_forward_hook(make_hook(key))) - - try: - for sample_idx, sample in enumerate( - tqdm(calibration_data, desc="Activation MSE (orig)", leave=False) - ): - inp = sample.unsqueeze(0) if sample.dim() == 1 else sample - inp = inp.to(device) - captured.clear() - _ = model_orig(inp) - for key in layer_keys: - if key not in captured: - continue - if sample_idx >= len(quant_outputs.get(key, [])): - continue - o = captured[key].float() - q = quant_outputs[key][sample_idx].float() - if o.shape != q.shape: - continue - sum_sq[key] += F.mse_loss(o, q, reduction="sum").item() - count[key] += o.numel() - finally: - for h in hooks: - h.remove() - - # Free original model from GPU - model_orig.cpu() - gc.collect() - torch.cuda.empty_cache() - - # Move quantized model back to GPU for downstream usage - model_quant.to(device) - - mse = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") for key in layer_keys - } - - if log_wandb: - try: - import wandb - - for key, val in mse.items(): - if val == val: # skip nan - wandb.log({f"activation_mse/{key}": val}) - except ImportError: - pass - - return mse - - -# --------------------------------------------------------------------------- -# Portable ActivationMSELogger class -# --------------------------------------------------------------------------- - - -def _matches_filter(name: str, layer_filter: str | None) -> bool: - """Check if a layer name matches the optional filter pattern (fnmatch-style).""" - if layer_filter is None: - return True - return fnmatch.fnmatch(name, layer_filter) - - -def _portable_discover_target_layers( - model: nn.Module, - layer_filter: str | None 
= None, -) -> dict[str, nn.Module]: - """Discover linear layers in decoder blocks with a portable fallback chain. - - Strategy: - 1. Try modelopt's ``get_decoder_layers`` (available inside ModelOpt). - 2. Try common HuggingFace attribute paths (``model.model.layers``, etc.). - 3. Fall back to scanning **all** ``nn.Linear`` in ``model.named_modules()``. - - Within each set of decoder blocks the function collects every ``nn.Linear`` - sub-module and optionally filters by *layer_filter* (fnmatch pattern). - """ - decoder_layers = None - - # 1. Try modelopt helper - try: - from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector - - decoder_layers = LayerActivationCollector.get_decoder_layers(model) - except Exception: - pass - - # 2. Try common HF / other patterns - if decoder_layers is None: - for attr_chain in ( - ("model", "layers"), - ("decoder", "layers"), - ("transformer", "h"), - ("backbone", "layers"), - ): - obj = model - try: - for attr in attr_chain: - obj = getattr(obj, attr) - if isinstance(obj, nn.ModuleList): - decoder_layers = obj - break - except AttributeError: - continue - - targets: dict[str, nn.Module] = {} - - if decoder_layers is not None: - module_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} - for block in decoder_layers: - block_name = module_to_name.get(id(block), "") - for sub_name, sub_mod in block.named_modules(): - if isinstance(sub_mod, nn.Linear): - full_name = f"{block_name}.{sub_name}" if block_name else sub_name - if _matches_filter(full_name, layer_filter): - targets[full_name] = sub_mod - else: - # 3. Fallback: all linear layers - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - if _matches_filter(name, layer_filter): - targets[name] = module - - return targets - - -class ActivationMSELogger: - """Portable activation MSE logger for comparing original vs quantized models. 
- - Works with both: - - - ``List[Tensor]`` data (**FP-Quant** style): each element is ``[1, seq_len]`` - or ``[B, seq_len]``, consumed via ``model(tensor)``. - - ``DataLoader`` / ``Iterable`` yielding dicts (**ModelOpt** style): - ``{"input_ids": tensor, ...}``, consumed via ``model(**batch)``. - - Guarantees same samples are used for both phases via SHA-256 hashing of - input tensors. Supports saving / loading all activations to disk for - later cross-codebase comparison. - - Example (FP-Quant -- List[Tensor]):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model_orig, calibration_data, phase="original") - mse_logger.collect(model_quant, calibration_data, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - - Example (ModelOpt -- DataLoader with dict batches):: - - mse_logger = ActivationMSELogger(max_samples=16, save_dir="./mse_logs") - mse_logger.collect(model, dataloader, phase="original") - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - mse_logger.collect(model, dataloader, phase="quantized") - results = mse_logger.compute_mse() - print(mse_logger.summary()) - mse_logger.save() - """ - - def __init__( - self, - max_samples: int = 16, - layer_filter: str | None = None, - save_dir: str | None = None, - ): - self.max_samples = max_samples - self.layer_filter = layer_filter - self.save_dir = save_dir - - # Per-phase state - self.original_activations: dict[str, list[torch.Tensor]] = {} - self.quantized_activations: dict[str, list[torch.Tensor]] = {} - self.input_hashes: list[str] = [] # hashes for "original" phase - self.quant_input_hashes: list[str] = [] # hashes for "quantized" phase - - # Computed after both phases - self.mse_results: dict[str, float] | None = None - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - 
@torch.no_grad() - def collect( - self, - model: nn.Module, - data, - phase: str, - target_modules: dict[str, nn.Module] | None = None, - ) -> None: - """Collect per-linear-layer output activations for a given phase. - - Args: - model: The model to run (original or quantized). - data: An iterable of batches. Each batch can be: - - - ``torch.Tensor`` with shape ``[B, seq_len]`` (FP-Quant style). - - ``dict`` with at least an ``"input_ids"`` key (ModelOpt style). - - ``list`` / ``tuple`` of tensors. - phase: ``"original"`` or ``"quantized"``. - target_modules: Optional explicit mapping of ``{name: nn.Module}`` - to attach hooks to. If *None*, layers are auto-discovered - via decoder-block scanning. - """ - if phase not in ("original", "quantized"): - raise ValueError(f"phase must be 'original' or 'quantized', got {phase!r}") - - was_training = model.training - model.eval() - - # ----- layer discovery ----- - targets = ( - target_modules - if target_modules is not None - else (_portable_discover_target_layers(model, self.layer_filter)) - ) - if not targets: - raise ValueError( - "No linear layers found. Provide target_modules explicitly or " - f"check layer_filter={self.layer_filter!r}." 
- ) - - print( - f"[ActivationMSELogger] Phase '{phase}': hooking {len(targets)} layers, " - f"max_samples={self.max_samples}" - ) - - # ----- storage ----- - saved: dict[str, list[torch.Tensor]] = {name: [] for name in targets} - captured: dict[str, torch.Tensor] = {} - hashes: list[str] = [] - - def _make_hook(key: str): - def hook(_module: nn.Module, _input, output) -> None: - captured[key] = _tensor_from_output(output).cpu() - - return hook - - hooks = [] - for name, module in targets.items(): - hooks.append(module.register_forward_hook(_make_hook(name))) - - try: - n_batches = 0 - for batch in tqdm(data, desc=f"Collecting ({phase})", leave=False): - if self.max_samples is not None and n_batches >= self.max_samples: - break - - captured.clear() - self._run_batch(model, batch) - - for name in targets: - if name in captured: - saved[name].append(captured[name]) - - hashes.append(self._hash_batch(batch)) - n_batches += 1 - finally: - for h in hooks: - h.remove() - - model.train(was_training) - - # ----- store results on self ----- - if phase == "original": - self.original_activations = saved - self.input_hashes = hashes - else: - self.quantized_activations = saved - self.quant_input_hashes = hashes - # Verify sample consistency - if self.input_hashes: - self._verify_hashes() - - # Invalidate any previous MSE since we have new activations - self.mse_results = None - - print(f"[ActivationMSELogger] Collected {n_batches} batches for phase '{phase}'") - - def compute_mse(self) -> dict[str, float]: - """Compute per-layer MSE between original and quantized activations. - - Returns: - Dict mapping layer name to its MSE value. - - Raises: - ValueError: If either phase has not been collected yet. - """ - if not self.original_activations: - raise ValueError( - "No original activations collected. Call collect(..., phase='original') first." - ) - if not self.quantized_activations: - raise ValueError( - "No quantized activations collected. 
Call collect(..., phase='quantized') first." - ) - - common_keys = sorted( - set(self.original_activations.keys()) & set(self.quantized_activations.keys()) - ) - if not common_keys: - raise ValueError( - "No matching layer names between original and quantized activations. " - "Ensure the same model architecture / layer_filter is used for both phases." - ) - - orig_only = set(self.original_activations.keys()) - set(self.quantized_activations.keys()) - quant_only = set(self.quantized_activations.keys()) - set(self.original_activations.keys()) - if orig_only: - print( - f"[ActivationMSELogger] Warning: {len(orig_only)} layers only in original (skipped)" - ) - if quant_only: - print( - f"[ActivationMSELogger] Warning: {len(quant_only)} layers only in quantized (skipped)" - ) - - sum_sq: dict[str, float] = dict.fromkeys(common_keys, 0.0) - count: dict[str, int] = dict.fromkeys(common_keys, 0) - - for name in common_keys: - orig_list = self.original_activations[name] - quant_list = self.quantized_activations[name] - n = min(len(orig_list), len(quant_list)) - for i in range(n): - o = orig_list[i].float() - q = quant_list[i].float() - if o.shape != q.shape: - print( - f"[ActivationMSELogger] Warning: shape mismatch for {name} " - f"batch {i}: {o.shape} vs {q.shape}, skipping" - ) - continue - sum_sq[name] += F.mse_loss(o, q, reduction="sum").item() - count[name] += o.numel() - - self.mse_results = { - key: (sum_sq[key] / count[key]) if count[key] > 0 else float("nan") - for key in common_keys - } - return self.mse_results - - def save(self, path: str | None = None) -> str: - """Save all state (activations, hashes, MSE) to disk via ``torch.save``. - - Args: - path: Explicit file path. If *None*, a timestamped file is created - inside ``self.save_dir`` (which must be set). - - Returns: - The path where the file was saved. 
- """ - if path is None: - if self.save_dir is None: - raise ValueError("Provide a path or set save_dir in the constructor.") - os.makedirs(self.save_dir, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - path = os.path.join(self.save_dir, f"activation_mse_{ts}.pt") - - payload = { - "max_samples": self.max_samples, - "layer_filter": self.layer_filter, - "input_hashes": self.input_hashes, - "quant_input_hashes": self.quant_input_hashes, - "original_activations": self.original_activations, - "quantized_activations": self.quantized_activations, - "mse": self.mse_results, - } - torch.save(payload, path) - print(f"[ActivationMSELogger] Saved to {path}") - return path - - @classmethod - def load(cls, path: str) -> "ActivationMSELogger": - """Load a previously saved ``ActivationMSELogger`` from disk. - - Args: - path: Path to the ``.pt`` file created by :meth:`save`. - - Returns: - A new ``ActivationMSELogger`` instance with restored state. - """ - payload = torch.load(path, map_location="cpu", weights_only=False) - logger = cls( - max_samples=payload.get("max_samples", 16), - layer_filter=payload.get("layer_filter"), - ) - logger.original_activations = payload.get("original_activations", {}) - logger.quantized_activations = payload.get("quantized_activations", {}) - logger.input_hashes = payload.get("input_hashes", []) - logger.quant_input_hashes = payload.get("quant_input_hashes", []) - logger.mse_results = payload.get("mse") - print(f"[ActivationMSELogger] Loaded from {path}") - return logger - - def summary(self) -> str: - """Return a formatted string summarising per-layer MSE results. - - Computes MSE first if not already done. 
- """ - if self.mse_results is None: - self.compute_mse() - assert self.mse_results is not None - - lines = ["Per-layer activation MSE (original vs quantized):"] - for key in sorted(self.mse_results.keys()): - lines.append(f" {key}: {self.mse_results[key]:.6e}") - return "\n".join(lines) - - # ------------------------------------------------------------------ - # Pre-materialized MSE data (cross-run / cross-codebase safety) - # ------------------------------------------------------------------ - - @staticmethod - def materialize_data( - data, - path: str, - max_samples: int | None = None, - ) -> list[torch.Tensor]: - """Freeze the first *max_samples* batches from *data* into a ``.pt`` file. - - Each batch (``dict``, ``Tensor``, or ``list/tuple``) is normalised to a - single ``input_ids`` CPU tensor before saving. The resulting file is a - plain ``List[Tensor]`` that can be loaded in **any** codebase and passed - straight to :meth:`collect`. - - If *path* already exists it is **not** overwritten -- call - :meth:`load_data` instead. - - Args: - data: Iterable of batches (DataLoader, List[Tensor], etc.). - path: Destination ``.pt`` file path. - max_samples: How many batches to keep. ``None`` means all. - - Returns: - The materialised list of CPU tensors (same object that was saved). 
- """ - samples: list[torch.Tensor] = [] - for batch in data: - if max_samples is not None and len(samples) >= max_samples: - break - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - samples.append(t.cpu()) - - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - torch.save(samples, path) - print(f"[ActivationMSELogger] Materialised {len(samples)} MSE input samples -> {path}") - return samples - - @staticmethod - def load_data(path: str) -> list[torch.Tensor]: - """Load a previously materialised MSE input set. - - Args: - path: Path to the ``.pt`` file created by :meth:`materialize_data`. - - Returns: - ``List[Tensor]`` of input batches (on CPU). - """ - samples = torch.load(path, map_location="cpu", weights_only=True) - print(f"[ActivationMSELogger] Loaded {len(samples)} MSE input samples from {path}") - return samples - - # ------------------------------------------------------------------ - # Raw-text materialization (cross-model / cross-tokenizer reuse) - # ------------------------------------------------------------------ - - @staticmethod - def materialize_raw_text( - data, - path: str, - tokenizer=None, - max_samples: int | None = None, - ) -> list[str]: - """Save raw text strings to a JSON file for cross-model reuse. - - Extracts text from batches by decoding ``input_ids`` with the provided - *tokenizer*. The saved JSON file can be loaded by any model regardless - of its vocabulary and re-tokenized via :meth:`tokenize_raw_text`. - - Args: - data: Iterable of batches (DataLoader, ``List[Tensor]``, etc.). - path: Destination ``.json`` file path. - tokenizer: A HuggingFace tokenizer with a ``decode`` method. - Required to convert token IDs back to text. - max_samples: How many batches to keep. ``None`` means all. 
- - Returns: - The list of decoded text strings (same content that was saved). - """ - if tokenizer is None: - raise ValueError( - "tokenizer is required for materialize_raw_text to decode input_ids back to text." - ) - - texts: list[str] = [] - for batch in data: - if max_samples is not None and len(texts) >= max_samples: - break - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - - if t.dim() == 1: - t = t.unsqueeze(0) - for row in t: - if max_samples is not None and len(texts) >= max_samples: - break - texts.append(tokenizer.decode(row, skip_special_tokens=True)) - - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - payload = {"texts": texts, "max_samples": len(texts)} - with open(path, "w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - - print(f"[ActivationMSELogger] Saved {len(texts)} raw text samples -> {path}") - return texts - - @staticmethod - def load_raw_text(path: str) -> list[str]: - """Load raw text strings from a JSON file created by :meth:`materialize_raw_text`. - - Args: - path: Path to the ``.json`` file. - - Returns: - List of raw text strings. - """ - with open(path, encoding="utf-8") as f: - payload = json.load(f) - texts = payload["texts"] - print(f"[ActivationMSELogger] Loaded {len(texts)} raw text samples from {path}") - return texts - - @staticmethod - def tokenize_raw_text( - texts: list[str], - tokenizer, - max_length: int = 2048, - ) -> list[torch.Tensor]: - """Tokenize raw text strings into a ``List[Tensor]`` for :meth:`collect`. - - Each string is independently tokenized and truncated to *max_length*. - Returns one ``[1, seq_len]`` tensor per string — the same format - expected by :meth:`collect` and :func:`compute_perplexity`. 
- - Args: - texts: List of raw text strings (from :meth:`load_raw_text`). - tokenizer: A HuggingFace tokenizer. - max_length: Maximum token length per sample (default: 2048). - - Returns: - ``List[Tensor]`` of tokenized inputs on CPU. - """ - samples: list[torch.Tensor] = [] - for text in texts: - encoded = tokenizer( - text, - return_tensors="pt", - max_length=max_length, - truncation=True, - add_special_tokens=False, - ) - samples.append(encoded.input_ids.cpu()) - print(f"[ActivationMSELogger] Tokenized {len(samples)} samples (max_length={max_length})") - return samples - - # ------------------------------------------------------------------ - # Static / private helpers - # ------------------------------------------------------------------ - - @staticmethod - def _run_batch(model: nn.Module, batch) -> None: - """Run a single batch through the model (handles Tensor, dict, list/tuple). - - Automatically moves inputs to the model's device so that CPU-stored - materialized data works transparently with a CUDA model. - """ - device = next(model.parameters()).device - if isinstance(batch, dict): - batch = { - k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() - } - model(**batch) - elif isinstance(batch, torch.Tensor): - model(batch.to(device)) - elif isinstance(batch, (list, tuple)): - batch = tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch) - model(*batch) - else: - raise TypeError(f"Unsupported batch type: {type(batch)}") - - @staticmethod - def _hash_batch(batch) -> str: - """Compute SHA-256 hash of the primary input tensor in *batch*. - - - ``dict`` -> hashes ``batch["input_ids"]`` (falls back to first value). - - ``Tensor`` -> hashes the tensor directly. - - ``list/tuple`` -> hashes the first element. 
- """ - if isinstance(batch, dict): - t = batch.get("input_ids", next(iter(batch.values()))) - elif isinstance(batch, torch.Tensor): - t = batch - elif isinstance(batch, (list, tuple)): - t = batch[0] if batch else None - else: - return "" - - if t is None or not isinstance(t, torch.Tensor): - return "" - return hashlib.sha256(t.cpu().contiguous().numpy().tobytes()).hexdigest() - - def _verify_hashes(self) -> None: - """Compare input hashes between original and quantized phases.""" - n = min(len(self.input_hashes), len(self.quant_input_hashes)) - mismatches = sum(1 for i in range(n) if self.input_hashes[i] != self.quant_input_hashes[i]) - if mismatches: - print( - f"[ActivationMSELogger] WARNING: {mismatches}/{n} batches have " - f"different input hashes between original and quantized phases. " - f"The same data may not have been used for both phases!" - ) - else: - print(f"[ActivationMSELogger] Input hash verification passed ({n}/{n} match)") diff --git a/modelopt/torch/quantization/metrics/perplexity.py b/modelopt/torch/quantization/metrics/perplexity.py deleted file mode 100644 index 2b592914ae..0000000000 --- a/modelopt/torch/quantization/metrics/perplexity.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# mypy: ignore-errors -# ruff: noqa: D103, PERF401 - -"""Perplexity evaluation for language models. - -Ported from FP-Quant: https://github.com/IST-DASLab/FP-Quant -""" - -import torch -import torch.nn.functional as F -from tqdm import trange - - -@torch.no_grad() -def compute_perplexity(model, data, batch_size: int = 1): - num_samples = len(data) - device = next(model.parameters()).device - # Running estimate of negative log-likelihood - nll_running = 0 - # Number of tokens processed to far - tokens_processed = 0 - # Loop through each batch - for i in trange(0, num_samples, batch_size, desc="Computing perplexity", leave=False): - j = min(i + batch_size, num_samples) - inputs = torch.cat(data[i:j]).to(device) - # Forward pass through the model - lm_logits = model(inputs).logits - # Shift logits and labels for next token prediction - shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inputs[:, 1:] - # Compute loss - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1) - ) - # Calculate negative log likelihood - a = shift_labels.numel() / (tokens_processed + shift_labels.numel()) - b = tokens_processed / (tokens_processed + shift_labels.numel()) - nll_running = a * loss + b * nll_running - # Update number of processed tokens - tokens_processed += shift_labels.numel() - # Compute perplexity - ppl = nll_running.exp().item() - return ppl - - -def get_wikitext2(tokenizer, sequence_length: int): - """Load WikiText-2 test set as a list of tokenized sequences for perplexity evaluation. - - Args: - tokenizer: HuggingFace tokenizer. - sequence_length: Length of each evaluation sequence. - - Returns: - List of tensors, each of shape ``[1, sequence_length]``. 
- """ - from datasets import load_dataset - - test_dataset_raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - test_dataset_tok = tokenizer( - "\n\n".join(test_dataset_raw["text"]), return_tensors="pt" - ).input_ids - num_test_sequences = test_dataset_tok.numel() // sequence_length - test_loader = [] - for i in range(num_test_sequences): - test_loader.append(test_dataset_tok[:, i * sequence_length : (i + 1) * sequence_length]) - return test_loader From 25744756d1bd4f455c4dac3007fb9154aaeb646c Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Sun, 22 Mar 2026 02:32:16 +0000 Subject: [PATCH 35/52] refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 27 +- modelopt/torch/quantization/mode.py | 14 - modelopt/torch/quantization/model_calib.py | 670 +++++------------- .../quantization/triton/gptq_fused_kernel.py | 189 ----- tests/gpu/torch/quantization/test_gptq.py | 104 +-- 5 files changed, 182 insertions(+), 822 deletions(-) delete mode 100644 modelopt/torch/quantization/triton/gptq_fused_kernel.py diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 37e0628465..f993461e65 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1567,23 +1567,19 @@ class SVDQuantConfig(QuantizeAlgorithmConfig): ) -class GPTQLiteConfig(QuantizeAlgorithmConfig): - """The config for GPTQ lite. - - GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. +class GPTQConfig(QuantizeAlgorithmConfig): + """The config for GPTQ quantization. - GPTQ lite does not perform sequential quantization of layers. This means that the updated - activations are not used to process the next layer. 
+ GPTQ minimizes the layer-wise quantization error by using second-order (Hessian) information + to perform blockwise weight updates that compensate for rounding loss. Layers are quantized + sequentially so that each layer's Hessian is computed from activations that already reflect + the quantization of preceding layers. The default values are taken from the official GPTQ implementation: https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 - - Note: This feature is currently experimental and may not translate to improved accuracy as expected. - - """ - method: Literal["gptq_lite"] = ModeloptField("gptq_lite") + method: Literal["gptq"] = ModeloptField("gptq") percdamp: float | None = ModeloptField( default=0.01, gt=0.0, @@ -1597,15 +1593,6 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig): description="""The block size for GPTQ weight update, which must be a multiple of the group_size used in the quantization.""", ) - hessian_state_path: str | None = ModeloptField( - default=None, - title="Path to the Hessian state file.", - description="""The path to the Hessian state file. 
If hessian path exists, we load from - hessian file instead of recomputing them.""", - ) - - -QuantizeQuantCfgType = list[QuantizerCfgEntry] QuantizeQuantCfgType = dict[ str | Callable, diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index df48c72c29..63b3a7c913 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -38,7 +38,6 @@ AWQLiteCalibConfig, CompressConfig, GPTQConfig, - GPTQLiteConfig, LocalHessianCalibConfig, MaxCalibConfig, MseCalibConfig, @@ -61,7 +60,6 @@ from .model_calib import ( awq, gptq, - gptq_lite, local_hessian_calibrate, max_calibrate, mse_calibrate, @@ -494,18 +492,6 @@ def restore(self) -> RestoreEntrypoint: return restore_svdquant_model -@CalibrateModeRegistry.register_mode -class GPTQLiteModeDescriptor(BaseCalibrateModeDescriptor): - """Mode for GPTQ calibration algorithm.""" - - @property - def config_class(self) -> type[QuantizeAlgorithmConfig]: - """Specifies the config class for the mode.""" - return GPTQLiteConfig - - _calib_func = gptq_lite - - @CalibrateModeRegistry.register_mode class GPTQModeDescriptor(BaseCalibrateModeDescriptor): """Mode for GPTQ calibration algorithm.""" diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 8b40bd53d5..15ab2bec65 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,7 +16,6 @@ """Calibration utilities.""" import math -import os import time import warnings from collections.abc import Callable @@ -1571,461 +1570,187 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -def _print_relative_mse_error( - q: torch.Tensor, - w: torch.Tensor, - h: torch.Tensor, - module_name: str, - n_samples: int | None = None, -): - """Print relative mean squared error between quantized and original weights. 
- - Computes the Hessian-weighted relative MSE between quantized and original weights, - providing a measure of quantization quality. This metric is adapted from the GPTQ - repository. - - Args: - q (torch.Tensor): Quantized weight tensor - w (torch.Tensor): Original weight tensor - h (torch.Tensor): Hessian matrix used for weighting the error - module_name (str): Name of the module for logging purposes - n_samples (int | None): Number of Hessian samples (batches) used for this layer - Note: - Implementation adapted from the GPTQ repository: - https://github.com/IST-DASLab/FP-Quant - """ - delta = q - w - mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) - suffix = f", n_hessian_samples: {n_samples}" if n_samples is not None else "" - print_rank_0(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}") - - -def update_hessian(input, hessian, n_samples): - """Update hessian matrix with new input samples using incremental formula. - - Args: - input: Input tensor (batch_size, ..., features) - hessian: Current Hessian matrix to update in-place - n_samples: Number of samples already processed - Returns: - Tuple of (updated_hessian, new_sample_count) - """ - # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens - input_flat = input.reshape(-1, input.shape[-1]).t().float() - batch_size = input_flat.shape[1] - - # Incremental averaging: scale down old hessian - hessian *= n_samples / (n_samples + batch_size) - n_samples += batch_size - - # Compute outer product: H += (2/n_samples) * X @ X^T - scaled_input = math.sqrt(2 / n_samples) * input_flat - hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) - - return hessian, n_samples - - -def prepare_hessian_inverse(h, weight, percdamp): - """Prepare inverse Hessian with dead neuron handling and damping. 
- - Args: - h: Hessian matrix to update - weight: Weight tensor to prepare Hessian for - percdamp: Damping percentage for Hessian diagonal - Returns: - h_inv: Inverse Hessian matrix - Implementation adapted from the FP-Quant repository: - https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 - """ - h = h.clone() - # Handle dead neurons (zero weight columns) - # Get columns with all zeros in weight - zero_cols = torch.nonzero(weight.eq(0).all(dim=0)).unsqueeze(-1) - - # Zero out entire rows and columns in Hessian for dead neurons - h[zero_cols, :] = 0 - h[:, zero_cols] = 0 - h[zero_cols, zero_cols] = 1 - - # Add damping to diagonal - damp = percdamp * torch.mean(torch.diag(h)) - diag_indices = torch.arange(h.shape[0], device=h.device) - h[diag_indices, diag_indices] += damp - - try: - h = torch.cholesky_inverse(torch.linalg.cholesky(h)) - h_inv = torch.linalg.cholesky(h, upper=True) - except (RuntimeError, torch.linalg.LinAlgError): - print_rank_0("Warning: Hessian is not positive definite, using identity matrix") - h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) - return h_inv - - -def _build_column_qdq(quantizer, weight_shape): - """Build a fast column-wise quantize-dequantize function for integer quantizers. +class GPTQHandle: + """Encapsulates per-module GPTQ state and operations. - Instead of calling the full TensorQuantizer on the entire weight matrix (which - quantizes all elements) and extracting one column, this returns a closure that - quantizes only a single column using the quantizer's pre-computed amax/scales. + Owns the Hessian, patches the forward during collection, and contains + the blockwise weight-update logic. - Since max_calibrate fixes the amax before GPTQ weight updates, quantizing a - single column with the same fixed scale gives bit-identical results to - quantizing the full matrix and extracting that column. 
+ Instance attributes set during ``__init__``: + module, name, hessian, n_samples - Args: - quantizer: The weight TensorQuantizer (already calibrated). - weight_shape: Shape of the weight tensor (out_features, in_features). - - Returns: - Tuple of (column_qdq_fn, supported) where: - - column_qdq_fn(column, col_idx) -> qdq_column (if supported) - - supported: True if column-wise qdq is available, False to fall back. + Instance attributes set during ``quantize``: + weight: float working copy of module weights (mutated in-place by update methods) + h_inv: upper-triangular Cholesky factor of the damped inverse Hessian """ - # Unsupported: NVFP4 (two-level FP4 scaling), FP quantization (num_bits is a tuple) - if isinstance(quantizer, NVFP4StaticQuantizer): - return None, False - if isinstance(quantizer._num_bits, tuple): - return None, False - - # Unsupported: pre_quant_scale (SmoothQuant) or rotation transforms mix columns - if getattr(quantizer, "pre_quant_scale", None) is not None: - return None, False - if getattr(quantizer, "rotate_is_enabled", False): - return None, False - - # Need calibrated amax - if not hasattr(quantizer, "_amax") or quantizer._amax is None: - return None, False - - num_bits = quantizer._num_bits - unsigned = getattr(quantizer, "_unsigned", False) - narrow_range = getattr(quantizer, "_narrow_range", False) - max_bound = (2 ** (num_bits - 1 + int(unsigned))) - 1 - min_bound = -max_bound + int(narrow_range) - - amax = quantizer._amax.float() - out_features, in_features = weight_shape - - # Determine quantization geometry from block_sizes - block_sizes = quantizer.block_sizes - group_size = None - if block_sizes is not None: - # Skip dynamic block quantization - if block_sizes.get("type", "static") == "dynamic": - return None, False - group_size = block_sizes.get(-1, None) or block_sizes.get(len(weight_shape) - 1, None) - - if group_size is not None and group_size > 0: - # Per-group block quantization along last dim. 
- # After _setup_for_blockquant, weight is reshaped to (-1, group_size) with axis=(0,). - # amax shape: (out_features * n_groups, 1) where n_groups = in_features // group_size. - if in_features % group_size != 0: - return None, False # Padding case — fall back - - n_groups = in_features // group_size - - try: - # Reshape amax to (out_features, n_groups) for O(1) group lookup - amax_2d = amax.reshape(out_features, n_groups) - except RuntimeError: - return None, False - - def _column_qdq_group( - col, col_idx, _a=amax_2d, _mx=max_bound, _mn=min_bound, _gs=group_size - ): - col_scale = _mx / _a[:, col_idx // _gs].clamp(min=1e-12) - return torch.clamp(torch.round(col * col_scale), _mn, _mx) / col_scale - - return _column_qdq_group, True - - # Per-channel (axis != None) or per-tensor (axis == None) - axis = quantizer.axis - if axis is not None: - # Per-channel: amax has shape (out_features, 1) or similar - col_scale = max_bound / amax.reshape(-1).clamp(min=1e-12) - - def _column_qdq_channel(col, col_idx, _s=col_scale, _mx=max_bound, _mn=min_bound): - return torch.clamp(torch.round(col * _s), _mn, _mx) / _s - - return _column_qdq_channel, True - # Per-tensor: single scalar scale - scalar_scale = max_bound / amax.clamp(min=1e-12).item() + CACHE_NAME = "_forward_no_gptq_hessian" + + def __init__(self, module, name, offload_to_cpu=False): + self.module = module + self.name = name + in_features = module.weight.shape[-1] + device = module.weight.device + if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + device = "cpu" + self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) + self.n_samples = 0 + # Set by quantize(); listed here for documentation. 
+ self.weight: torch.Tensor | None = None + self.h_inv: torch.Tensor | None = None + + def setup(self): + """Patch the module's forward to accumulate Hessian during the collection pass.""" + gptq_handle = self - def _column_qdq_tensor(col, col_idx, _s=scalar_scale, _mx=max_bound, _mn=min_bound): - return torch.clamp(torch.round(col * _s), _mn, _mx) / _s - - return _column_qdq_tensor, True - - -def _can_use_fused_gptq(quantizer) -> bool: - """Check whether the fused Triton GPTQ kernel can be used for *quantizer*.""" - if not isinstance(quantizer, NVFP4StaticQuantizer): - return False - if not hasattr(quantizer, "_amax") or quantizer._amax is None: - return False - from modelopt.torch.quantization.triton import IS_AVAILABLE as _TRITON_OK - - return _TRITON_OK + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + if self.input_quantizer is not None and self.input_quantizer.is_enabled: + hessian_input = self.input_quantizer(inp) + else: + hessian_input = inp + gptq_handle.hessian, gptq_handle.n_samples = update_hessian( + hessian_input, gptq_handle.hessian, gptq_handle.n_samples + ) + self.weight_quantizer.disable() + out = self._forward_no_gptq_hessian(input, *args, **kwargs) + self.weight_quantizer.enable() + return out -def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None): - """Update module weights using GPTQ-style blockwise quantization. + bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) - Dispatches to one of three internal paths depending on quantizer type: + def cleanup(self): + """Unpatch the module's forward method.""" + unpatch_forward_method(self.module, self.CACHE_NAME) - 1. **Fused Triton** — for :class:`NVFP4StaticQuantizer` when Triton is - available. Runs the entire column loop in a single GPU kernel per - block (~130x faster than the unfused path on Blackwell GPUs). - 2. 
**Column-QDQ** — for integer quantizers whose scale geometry allows - single-column fake-quant via :func:`_build_column_qdq`. - 3. **Full-matrix fallback** — calls the quantizer on the full weight matrix - each column (slowest, but always correct). + def quantize(self, block_size, percdamp): + """Run GPTQ blockwise weight update on this module. - Args: - module: Neural network module with ``weight`` and ``weight_quantizer``. - h: Hessian matrix of shape ``(d, d)``. - block_size: Number of columns processed per block. - percdamp: Damping as a fraction of the mean Hessian diagonal. - n_samples: Number of Hessian samples (used only for logging). - """ - weight = module.weight.data.float().clone() - num_rows, num_cols = weight.shape + Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, + logs MSE, and writes the result back to the module. + """ + hessian = self.hessian.to(self.module.weight.device) + self.weight = self.module.weight.data.float().clone() + self._prepare_hessian_inverse(hessian, percdamp) - h_inv = prepare_hessian_inverse(h, weight, percdamp) + self._blockwise_update(block_size) - quantizer = module.weight_quantizer - if _can_use_fused_gptq(quantizer): - _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size) - else: - col_qdq_fn, col_qdq_supported = _build_column_qdq(quantizer, weight.shape) - _blockwise_weight_update_unfused( - weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported + self._print_mse_error(hessian) + self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( + self.module.weight.data.dtype ) - _print_relative_mse_error(weight, module.weight.float(), h, module.name, n_samples) - module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype) - + # ------------------------------------------------------------------ + # Quantize helpers — all read from self.module, self.weight, self.h_inv + # 
------------------------------------------------------------------ -def _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size): - """Fused Triton path for NVFP4: one kernel launch per block.""" - from modelopt.torch.quantization.triton.gptq_fused_kernel import gptq_fused_block + def _prepare_hessian_inverse(self, hessian, percdamp): + """Compute damped inverse Hessian and store as ``self.h_inv``. - group_size = quantizer.block_sizes.get(-1, None) or quantizer.block_sizes.get(1, None) - num_groups = math.ceil(num_cols / group_size) - amax_grouped = quantizer._amax.float().reshape(num_rows, num_groups).contiguous() - global_amax = quantizer.global_amax.float() + Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the + Hessian before inversion, matching the FP-Quant reference: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 + """ + assert self.weight is not None, "_prepare_hessian_inverse called before quantize()" + h = hessian.clone() + zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) - for block_start in range(0, num_cols, block_size): - block_end = min(block_start + block_size, num_cols) - n_cols_blk = block_end - block_start + h[zero_cols, :] = 0 + h[:, zero_cols] = 0 + h[zero_cols, zero_cols] = 1 - w_block = weight[:, block_start:block_end].clone().contiguous() - h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end].contiguous() + damp = percdamp * torch.mean(torch.diag(h)) + diag_indices = torch.arange(h.shape[0], device=h.device) + h[diag_indices, diag_indices] += damp - qw_block, err_block = gptq_fused_block( - w_block, - amax_grouped, - global_amax, - h_inv_cho_blk, - group_size, - block_start, - n_cols_blk, + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + self.h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + print_rank_0("Warning: Hessian is not 
positive definite, using identity matrix") + self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) + + def _blockwise_update(self, block_size): + """Column-wise GPTQ update using full-matrix QDQ. + + For each column, quantizes the full weight matrix via the quantizer and + extracts the quantized column. This is the standard GPTQ approach. + + Reads/writes ``self.weight`` and ``self.h_inv`` in-place. + """ + assert self.weight is not None and self.h_inv is not None, ( + "_blockwise_update called before _prepare_hessian_inverse()" ) + quantizer = self.module.weight_quantizer + num_cols = self.weight.shape[1] - weight[:, block_start:block_end] = qw_block - if block_end < num_cols: - weight[:, block_end:].addmm_( - err_block[:, :n_cols_blk], - h_inv[block_start:block_end, block_end:], - alpha=-1, - ) + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + n_cols_blk = block_end - block_start + h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] - -def _blockwise_weight_update_unfused( - weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported -): - """Column-QDQ or full-matrix fallback for non-NVFP4 quantizers.""" - for block_start in range(0, num_cols, block_size): - block_end = min(block_start + block_size, num_cols) - n_cols_blk = block_end - block_start - h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end] - - # wblk is a scratch copy for intra-block error propagation; weight gets - # the final quantized values. Inter-block errors are propagated via addmm_ below. 
- if col_qdq_supported: - wblk = weight[:, block_start:block_end].clone() - errs = torch.zeros_like(wblk) - - for i in range(n_cols_blk): - w_ci = wblk[:, i] - d = h_inv_cho_blk[i, i] - qdq_col = col_qdq_fn(w_ci, block_start + i) - weight[:, block_start + i] = qdq_col - err = (w_ci - qdq_col) / d - wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) - errs[:, i] = err - else: - wblk = weight.clone() + wblk = self.weight.clone() errs = torch.zeros_like(wblk[:, block_start:block_end]) for i in range(n_cols_blk): w_ci = wblk[:, block_start + i] d = h_inv_cho_blk[i, i] qdq = quantizer(wblk) - weight[:, block_start + i] = qdq[:, block_start + i] + self.weight[:, block_start + i] = qdq[:, block_start + i] err = (w_ci - qdq[:, block_start + i]) / d wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) errs[:, i] = err - weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1) + self.weight[:, block_end:].addmm_( + errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 + ) + def _print_mse_error(self, hessian): + """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" + w_orig = self.module.weight.float() + delta = self.weight - w_orig + mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) + suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" + print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") -def gptq_lite( - model: nn.Module, - forward_loop: ForwardLoop | None = None, - percdamp: float = 0.01, - block_size: int = 128, - hessian_state_path: str | None = None, -): - """GPTQ-lite quantization - a simplified GPTQ variant. 
- Key differences from GPTQ: - - Layers are quantized in parallel (not sequentially with updated activations) - - Uses group-wise updates instead of column-wise updates +def update_hessian(input, hessian, n_samples): + """Update hessian matrix with new input samples using incremental formula. Args: - model: Model to be calibrated. - forward_loop: Callable that forwards calibration data through the model. - percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). - block_size: Block size for GPTQ weight update. - hessian_state_path: Path to save/load Hessian state. If None, compute without saving. - If path exists, load from it. If path doesnt exist then save computed hessians to path. - - See :class:`GPTQLiteConfig ` for - details on the remaining arguments. - - Note: This feature is currently experimental and may not translate to improved accuracy as expected. + input: Input tensor (batch_size, ..., features) + hessian: Current Hessian matrix to update in-place + n_samples: Number of samples already processed + Returns: + Tuple of (updated_hessian, new_sample_count) """ - # Dictionary to store hessian matrices: {layer_name: {"hessian": Tensor, "n_samples": int}} - hessian_state = {} - - def initialize_hessian_state(tensor_mapping): - """Initialize hessian state with zeros.""" - for name, (shape, device) in tensor_mapping.items(): - # Use CPU if GPU memory is tight - target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device - hessian_state[name] = { - "hessian": torch.zeros(shape, dtype=torch.float32, device=target_device), - "n_samples": 0, - } - - def load_hessian_state(path, tensor_mapping): - """Load hessian state from file.""" - print_rank_0(f"Loading hessian state from {path}") - loaded_state = torch.load(path, map_location="cpu") - - for name, (shape, device) in tensor_mapping.items(): - if name not in loaded_state: - raise KeyError(f"Layer '{name}' not found in loaded hessian state") - - # Move to appropriate device 
based on memory - target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device - hessian_state[name] = { - "hessian": loaded_state[name]["hessian"].to(target_device), - "n_samples": loaded_state[name]["n_samples"], - } + # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens + input_flat = input.reshape(-1, input.shape[-1]).t().float() + batch_size = input_flat.shape[1] - print_rank_0(f"Successfully loaded hessian state with {len(hessian_state)} layers") + # Incremental averaging: scale down old hessian + hessian *= n_samples / (n_samples + batch_size) + n_samples += batch_size - def save_hessian_state(path): - """Save hessian state to file.""" - print_rank_0(f"Saving hessian state to {path}") - try: - # Move to CPU for saving - cpu_state = { - name: {"hessian": state["hessian"].cpu(), "n_samples": state["n_samples"]} - for name, state in hessian_state.items() - } + # Compute outer product: H += (2/n_samples) * X @ X^T + scaled_input = math.sqrt(2 / n_samples) * input_flat + hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) - os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) - torch.save(cpu_state, path) - print_rank_0(f"Successfully saved hessian state to {path}") - except Exception as e: - print_rank_0(f"Error saving hessian state: {e}") - print_rank_0("Continuing execution...") + return hessian, n_samples - def hessian_hook(module, input, output): - """Hook to intercept activations and update hessian matrix.""" - state = hessian_state[module.name] - hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) - hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} - # Phase 1: Collect statistics for quantizers - max_calibrate(model) +def _get_quantized_linear_layers(parent: nn.Module) -> list[tuple[str, nn.Module]]: + """Return (name, module) pairs for all quantized linear layers with enabled weight quantizers. 
- # Phase 2: Build tensor mapping for all quantized layers - tensor_mapping = {} - for name, module in model.named_modules(): + Also sets ``module.name`` on each returned module for downstream logging. + """ + layers = [] + for name, module in parent.named_modules(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - in_features = module.weight.shape[-1] - tensor_mapping[name] = ((in_features, in_features), module.weight.device) - module.name = name # Attach name for easy access in hooks - - # Phase 3: Load or compute Hessians - hessian_exists = hessian_state_path is not None and os.path.exists(hessian_state_path) - save_hessians = hessian_state_path is not None and not hessian_exists - - if hessian_exists: - print_rank_0(f"Loading hessian state from {hessian_state_path}") - load_hessian_state(hessian_state_path, tensor_mapping) - else: - if forward_loop is None: - raise ValueError("forward_loop must be provided when computing Hessians") - - # Initialize hessian state - initialize_hessian_state(tensor_mapping) - - # Register hooks to collect activations - handles = [] - for name, module in model.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - handles.append(module.register_forward_hook(hessian_hook)) - - # Run forward loop to compute hessians - print_rank_0("Computing Hessian matrices...") - forward_loop(model) - - for handle in handles: - handle.remove() - - # Save if configured - if save_hessians: - try: - save_hessian_state(hessian_state_path) - except Exception as e: - print_rank_0(f"Error saving hessian state: {e}") - print_rank_0("Continuing execution...") - - # Phase 4: Update weights using computed Hessians - print_rank_0("Updating weights using GPTQ-lite algorithm...") - - quantized_modules = [ - (name, module) - for name, module in model.named_modules() - if is_quantized_linear(module) and module.weight_quantizer.is_enabled - ] - - # Perform blockwise weight updates - for name, module in 
tqdm(quantized_modules, desc="Quantizing layers"): - state = hessian_state[module.name] - hessian = state["hessian"].to(module.weight.device) - blockwise_weight_update(module, hessian, block_size, percdamp) - # Delete hessian state to free memory - del hessian_state[module.name] - torch.cuda.empty_cache() - - print_rank_0("GPTQ-lite quantization completed successfully") + module.name = name + layers.append((name, module)) + return layers @torch.no_grad() @@ -2107,17 +1832,22 @@ def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: @torch.no_grad() def gptq( - layer: nn.Module, + model: nn.Module, forward_loop: ForwardLoop, percdamp: float = 0.01, block_size: int = 128, - **kwargs, ): - """GPTQ quantization for a single decoder layer. + """GPTQ quantization. + + Works in two modes depending on ``use_sequential`` in the config: + + * **Sequential** (``use_sequential=True``): ``sequential_calibrate`` calls this + function once per decoder layer with updated activations, producing more + accurate Hessian estimates. + * **Non-sequential** (``use_sequential=False``): called once on the full model. + All layers are quantized in parallel from the original activations. - Invoked by ``sequential_calibrate`` which walks layers one at a time so each - layer sees activations already updated by the quantization of preceding layers. - Within a layer the steps are: + Per-module steps: 1. ``max_calibrate`` to set amax values from the current activations. 2. Promote eligible quantizers to ``NVFP4StaticQuantizer`` (two-level scaling). @@ -2125,106 +1855,38 @@ def gptq( 4. Blockwise weight updates using the inverse Hessian to compensate for rounding error (the core GPTQ column-wise update). - In contrast to ``gptq_lite``, which quantizes all layers in parallel using the - original (unquantized) activations, this method performs sequential calibration - and therefore produces more accurate Hessian estimates. - Args: - layer: A single decoder layer to quantize. 
- forward_loop: Callable that replays calibration inputs through the layer. - Provided by ``sequential_calibrate`` which captures per-layer activations. + model: The module to quantize — either the full model or a single decoder + layer when invoked by ``sequential_calibrate``. + forward_loop: Callable that replays calibration inputs through *model*. percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. """ total_start = time.time() - # Set weight amax and activation amax for the current layer using max_calibrate - max_calibrate(layer, forward_loop=forward_loop) - - # Promote NVFP4 static quantizers so they use the two-level scaling path - n_promoted = _promote_nvfp4_static_quantizers(layer) - if n_promoted: - print_rank_0(f"Promoted {n_promoted} quantizer(s) to NVFP4StaticQuantizer") + max_calibrate(model, forward_loop=forward_loop) + _promote_nvfp4_static_quantizers(model) - # Dictionary to store hessian matrices for all linear layers in this decoder - hessian_state = {} - - # Phase 1: Build tensor mapping for all quantized linear layers in this decoder layer - tensor_mapping = {} - for name, module in layer.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - in_features = module.weight.shape[-1] - tensor_mapping[name] = ((in_features, in_features), module.weight.device) - module.name = name # Attach name for easy access in hooks - - if not tensor_mapping: - print_rank_0("No quantized linear layers found in decoder layer, skipping GPTQ") + quantized_layers = _get_quantized_linear_layers(model) + if not quantized_layers: + print_rank_0("No quantized linear layers found, skipping GPTQ") return - # Initialize hessian state with zeros - for name, (shape, device) in tensor_mapping.items(): - hessian_state[name] = { - "hessian": torch.zeros(shape, dtype=torch.float32, device=device), - "n_samples": 0, - } - - # Phase 2: Patch forwards to collect Hessians (similar 
to local_hessian_calibrate) - def _make_hessian_forward(module_name): - def hessian_forward(self, input, *args, **kwargs): - inp = input.to_local() if hasattr(input, "to_local") else input - if self.input_quantizer is not None and self.input_quantizer.is_enabled: - hessian_input = self.input_quantizer(inp) - else: - hessian_input = inp - state = hessian_state[module_name] - hessian, n_samples = update_hessian(hessian_input, state["hessian"], state["n_samples"]) - hessian_state[module_name] = {"hessian": hessian, "n_samples": n_samples} - - self.weight_quantizer.disable() - out = self._forward_no_gptq_hessian(input, *args, **kwargs) - self.weight_quantizer.enable() - return out - - return hessian_forward + gptq_handles = {name: GPTQHandle(m, name, offload_to_cpu=True) for name, m in quantized_layers} + for handle in gptq_handles.values(): + handle.setup() - patched_modules = [] - for name, module in layer.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - bind_forward_method(module, _make_hessian_forward(name), "_forward_no_gptq_hessian") - patched_modules.append(module) - - # Run forward passes to collect Hessians - hessian_start = time.time() - print_rank_0(f"Computing Hessians for {len(tensor_mapping)} linear layers...") - forward_loop(layer) - - # Unpatch forwards - for module in patched_modules: - unpatch_forward_method(module, "_forward_no_gptq_hessian") + print_rank_0(f"Computing Hessians for {len(gptq_handles)} linear layers...") + forward_loop(model) - torch.cuda.synchronize() if torch.cuda.is_available() else None - hessian_time = time.time() - hessian_start + for handle in gptq_handles.values(): + handle.cleanup() - # Phase 3: Update weights using computed Hessians (same as gptq_lite) - weight_update_start = time.time() print_rank_0("Updating weights using GPTQ algorithm...") - for name, module in layer.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - state = 
hessian_state[module.name] - hessian = state["hessian"].to(module.weight.device) - blockwise_weight_update( - module, hessian, block_size, percdamp, n_samples=state["n_samples"] - ) - del hessian_state[module.name] + for handle in gptq_handles.values(): + handle.quantize(block_size, percdamp) + del gptq_handles + if torch.cuda.is_available(): torch.cuda.empty_cache() - - torch.cuda.synchronize() if torch.cuda.is_available() else None - weight_update_time = time.time() - weight_update_start - - total_time = time.time() - total_start - print_rank_0( - f"GPTQ timing - Hessian: {hessian_time:.2f}s, " - f"Weight update: {weight_update_time:.2f}s, " - f"Total: {total_time:.2f}s" - ) + print_rank_0(f"GPTQ time: {time.time() - total_start:.2f}s") diff --git a/modelopt/torch/quantization/triton/gptq_fused_kernel.py b/modelopt/torch/quantization/triton/gptq_fused_kernel.py deleted file mode 100644 index 21d84713a1..0000000000 --- a/modelopt/torch/quantization/triton/gptq_fused_kernel.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Fused Triton kernel for the GPTQ blockwise weight-update inner loop. - -The standard GPTQ inner loop launches ~10-15 CUDA kernels per column -(amax lookup, FP4 quantization, error computation, rank-1 update). 
-For ``block_size=128`` that is ~1 500 kernel launches per block, each with -~5-10 us of launch overhead dominating actual compute. - -This module fuses the entire inner loop into a **single** Triton kernel per -block. Rows are independent and map to Triton programs; columns are processed -sequentially inside each program so the rank-1 error update is carried forward -without synchronisation. - -Supported quantisation format: **NVFP4 static block quantisation** (two-level -scaling with per-group amax and a global amax). -""" - -import torch -import triton -import triton.language as tl - -__all__ = ["gptq_fused_block"] - -# -- NVFP4 constants used by the kernel ------------------------------------ -# Maximum representable FP4-E2M1 value (1 + 1 + 0.5 = 6.0 when decoded via -# the standard E2M1 table: {0, 0.5, 1, 1.5, 2, 3, 4, 6}). -_FP4_MAX = 6.0 -# FP8-E4M3 has max representable value 448. -_FP8_E4M3_MAX = 448.0 - - -@triton.jit -def _gptq_fused_block_kernel( - w_ptr, # [num_rows, BLOCK_SIZE] working weight block (in-place) - qw_ptr, # [num_rows, BLOCK_SIZE] output: quantized weights - err_ptr, # [num_rows, BLOCK_SIZE] output: quantization errors - amax_ptr, # [num_rows, num_groups] per-group amax, row-major - global_amax_ptr, # scalar float32 on device - hinv_ptr, # [BLOCK_SIZE, BLOCK_SIZE] upper Cholesky of H^{-1} - num_rows, - num_groups, - group_size: tl.constexpr, - block_start, # column offset of this block in the full weight matrix - n_cols, # actual columns in this block (may be < BLOCK_SIZE) - BLOCK_SIZE: tl.constexpr, -): - """One program per row; sequentially quantizes columns, propagating errors.""" - row = tl.program_id(0) - if row >= num_rows: - return - - # Base pointers for this row - w_base = w_ptr + row * BLOCK_SIZE - qw_base = qw_ptr + row * BLOCK_SIZE - err_base = err_ptr + row * BLOCK_SIZE - amax_row_base = amax_ptr + row * num_groups - - # Pre-compute global FP8 scale factors (constant across columns) - global_amax = 
tl.load(global_amax_ptr).to(tl.float32) - global_scale = global_amax / 6.0 # _FP4_MAX - fp8_inv_scale = tl.where(global_scale > 0.0, 1.0 / (448.0 / global_scale), 0.0) - - j_range = tl.arange(0, BLOCK_SIZE) - - for i in range(BLOCK_SIZE): - wi = tl.load(w_base + i) - - # -- Compute NVFP4 two-level scale for this column's group ----------- - col_idx = block_start + i - group_idx = col_idx // group_size - raw_amax = tl.load(amax_row_base + group_idx).to(tl.float32) - raw_scale = raw_amax / 6.0 # _FP4_MAX - - # FP8-quantize the block scale: scale * fp8_scale -> cast E4M3 -> back - fp8_scale = tl.where(global_scale > 0.0, 448.0 / global_scale, 1.0) - si = (raw_scale * fp8_scale).to(tl.float8e4nv).to(tl.float32) * fp8_inv_scale - - # Guard: replace zero / nan / inf scale with 1.0 - # NOTE: ``si != si`` is the standard NaN check in Triton (no math.isnan). - si_safe = tl.where( - (si == 0.0) | (si != si) | (tl.abs(si) == float("inf")), # noqa: PLR0124 - 1.0, - si, - ) - - # -- FP4-E2M1 fake quantization (nearest-round to 8 levels) ---------- - abs_scaled = tl.abs(wi) / si_safe - q_val = tl.where( - abs_scaled <= 0.25, - 0.0, - tl.where( - abs_scaled < 0.75, - 0.5, - tl.where( - abs_scaled <= 1.25, - 1.0, - tl.where( - abs_scaled < 1.75, - 1.5, - tl.where( - abs_scaled <= 2.5, - 2.0, - tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)), - ), - ), - ), - ), - ) - - qi = q_val * si_safe * tl.where(wi >= 0.0, 1.0, -1.0) - tl.store(qw_base + i, qi) - - # -- GPTQ error and rank-1 update ------------------------------------ - di = tl.load(hinv_ptr + i * BLOCK_SIZE + i) - err_i = (wi - qi) / di - tl.store(err_base + i, err_i) - - j_mask = (j_range > i) & (j_range < n_cols) - hinv_row = tl.load(hinv_ptr + i * BLOCK_SIZE + j_range, mask=j_mask, other=0.0) - w_rem = tl.load(w_base + j_range, mask=j_mask, other=0.0) - w_rem = w_rem - err_i * hinv_row - tl.store(w_base + j_range, w_rem, mask=j_mask) - - -def gptq_fused_block( - w_block: torch.Tensor, - 
amax_grouped: torch.Tensor, - global_amax: torch.Tensor, - h_inv_cho_blk: torch.Tensor, - group_size: int, - block_start: int, - n_cols: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """Run the GPTQ column loop for one block in a single Triton kernel launch. - - Args: - w_block: Working weight block of shape ``[num_rows, block_size]`` (will be cloned). - amax_grouped: Per-group amax of shape ``[num_rows, num_groups]``. - global_amax: Scalar tensor with the global amax. - h_inv_cho_blk: Upper Cholesky factor of H^{-1}, shape ``[block_size, block_size]``. - group_size: NVFP4 quantization group size (typically 16). - block_start: Column offset of this block in the full weight matrix. - n_cols: Actual number of columns in this block (``<= block_size``). - - Returns: - Tuple of ``(qw_block, err_block)`` each of shape ``[num_rows, block_size]``. - """ - num_rows, block_size = w_block.shape - num_groups = amax_grouped.shape[1] - - w_block = w_block.contiguous() - amax_grouped = amax_grouped.contiguous() - h_inv_cho_blk = h_inv_cho_blk.contiguous() - - qw_block = torch.empty_like(w_block) - err_block = torch.empty_like(w_block) - - grid = (num_rows,) - with torch.cuda.device(w_block.device): - _gptq_fused_block_kernel[grid]( - w_block, - qw_block, - err_block, - amax_grouped, - global_amax, - h_inv_cho_blk, - num_rows, - num_groups, - group_size, - block_start, - n_cols, - BLOCK_SIZE=block_size, - ) - - return qw_block, err_block diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 1203c20ef7..639d376acb 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -21,14 +21,7 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight -from modelopt.torch.quantization.model_calib import ( - _blockwise_weight_update_fused, - _blockwise_weight_update_unfused, - blockwise_weight_update, - prepare_hessian_inverse, - 
update_hessian, -) -from modelopt.torch.quantization.nn import NVFP4StaticQuantizer +from modelopt.torch.quantization.model_calib import GPTQHandle, update_hessian from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader @@ -163,8 +156,10 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): f"n_samples should be {expected_tokens_2}, got {n_samples}" ) - hessian = hessian.to(input.device) - blockwise_weight_update(model, hessian, block_size, 0.1) + handle = GPTQHandle(model, "linear") + handle.hessian = hessian.to(input.device) + handle.n_samples = n_samples + handle.quantize(block_size, 0.1) if expect_weight_change: # Weight must change as GPTQ updates weights to adjust for quantization error assert not torch.allclose(model.weight.data, q_dq_weight), "Weight should not be equal" @@ -196,7 +191,10 @@ def test_gptq_export_roundtrip(): hessian, n_samples = update_hessian(input_tensor, hessian, n_samples) hessian = hessian.to("cuda") - blockwise_weight_update(model, hessian, block_size, percdamp=0.1) + handle = GPTQHandle(model, "linear") + handle.hessian = hessian + handle.n_samples = n_samples + handle.quantize(block_size, percdamp=0.1) # Save the QDQ reference from the quantizer applied to GPTQ'd weights gptq_weight_shape = model.weight.data.shape @@ -307,87 +305,3 @@ def test_gptq_e2e_flow(quant_cfg): print( f"Generated ids after quantization: {tokenizer.decode(generated_ids_after_ptq[0], skip_special_tokens=True)}" ) - - -@pytest.mark.parametrize("dim", [256, 512]) -def test_fused_vs_unfused_nvfp4(dim): - """Verify that the fused Triton GPTQ kernel produces equivalent results to the unfused path. - - The fused kernel computes NVFP4 quantisation inline using Triton intrinsics, - which can differ slightly from the PyTorch-level quantiser path (different FP - rounding order). 
On real models (dim >= 4096) the relative MSE difference is - typically < 0.1%; at the smaller dims used here the tolerance is set to 20%. - """ - from modelopt.torch.quantization.model_calib import _promote_nvfp4_static_quantizers - - torch.manual_seed(RAND_SEED) - block_size = min(128, dim) - - # NVFP4_WEIGHT_ONLY_GPTQ_CFG uses *static* blocks, which get promoted to - # NVFP4StaticQuantizer — the prerequisite for the fused Triton path. - quant_cfg = copy.deepcopy(mtq.NVFP4_WEIGHT_ONLY_GPTQ_CFG) - quant_cfg["algorithm"] = "max" # calibrate only, don't run GPTQ - - model = torch.nn.Linear(dim, dim, bias=False).to("cuda") - model.name = "test_fused" - original_weight = model.weight.data.clone() - inp = torch.randn(4, 32, dim, device="cuda") - - mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(inp)) - - # Promote to NVFP4StaticQuantizer (normally done by gptq / sequential_calibrate) - n_promoted = _promote_nvfp4_static_quantizers(model) - assert n_promoted > 0, "Expected at least one quantizer to be promoted" - - quantizer = model.weight_quantizer - assert isinstance(quantizer, NVFP4StaticQuantizer), ( - f"Expected NVFP4StaticQuantizer, got {type(quantizer).__name__}" - ) - - # Restore original weight and compute Hessian - model.weight.data = original_weight.clone() - hessian = torch.zeros(dim, dim, dtype=torch.float32) - n_samples = 0 - hessian, n_samples = update_hessian(inp, hessian, n_samples) - hessian = hessian.to("cuda") - - # --- Run fused path --- - weight_fused = original_weight.float().clone() - num_rows, num_cols = weight_fused.shape - h_inv = prepare_hessian_inverse(hessian, weight_fused, percdamp=0.01) - _blockwise_weight_update_fused(weight_fused, h_inv, quantizer, num_rows, num_cols, block_size) - - # --- Run unfused path --- - weight_unfused = original_weight.float().clone() - h_inv_unfused = prepare_hessian_inverse(hessian, weight_unfused, percdamp=0.01) - _blockwise_weight_update_unfused( - weight_unfused, h_inv_unfused, quantizer, num_cols, 
block_size, None, False - ) - - # Both paths must produce non-trivial updates - assert not torch.equal(weight_fused, original_weight.float()), ( - "Fused path did not update weights" - ) - assert not torch.equal(weight_unfused, original_weight.float()), ( - "Unfused path did not update weights" - ) - - # Compare Hessian-weighted relative MSE - def _relative_mse(q, w, h): - delta = q - w - return (delta.mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6)).item() - - orig_f = original_weight.float() - mse_fused = _relative_mse(weight_fused, orig_f, hessian) - mse_unfused = _relative_mse(weight_unfused, orig_f, hessian) - - assert mse_fused > 0, "Fused MSE should be positive" - assert mse_unfused > 0, "Unfused MSE should be positive" - - # At small test dimensions, inline Triton FP4 rounding can diverge up to ~15% - # from the PyTorch path. On production-scale layers this drops below 0.1%. - relative_mse_diff = abs(mse_fused - mse_unfused) / max(mse_fused, mse_unfused) - assert relative_mse_diff < 0.20, ( - f"Fused ({mse_fused:.6e}) and unfused ({mse_unfused:.6e}) MSE differ by " - f"{relative_mse_diff:.2%}, expected < 20%" - ) From 0c6ec11e37c528e98d3ee85d84ea7dce3fe1338d Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:10:53 +0000 Subject: [PATCH 36/52] claude review + coderabbit review Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 36 ---------------------- modelopt/torch/quantization/model_calib.py | 25 ++++++--------- modelopt/torch/quantization/model_quant.py | 2 +- tests/gpu/torch/quantization/test_gptq.py | 2 +- 4 files changed, 11 insertions(+), 54 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index f993461e65..fb2e8096b4 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -626,39 +626,6 @@ def 
_nvfp4_selective_quant_cfg( }, } -NVFP4_WEIGHT_ONLY_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, - "algorithm": {"method": "gptq", "use_sequential": True}, -} - -NVFP4_GPTQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": {"method": "gptq", "use_sequential": True}, -} - MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { "quant_cfg": [ *_base_disable_all, @@ -849,9 +816,6 @@ def _nvfp4_selective_quant_cfg( "NVFP4_AWQ_FULL_CFG", "NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", - "NVFP4_GPTQ_CFG", - "NVFP4_WEIGHT_ONLY_CFG", - "NVFP4_WEIGHT_ONLY_GPTQ_CFG", "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 15ab2bec65..e41ee53005 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1624,7 +1624,7 @@ def cleanup(self): """Unpatch the module's forward method.""" unpatch_forward_method(self.module, self.CACHE_NAME) - def quantize(self, block_size, percdamp): + def update_weights(self, block_size, percdamp): """Run GPTQ blockwise weight update on this module. 
Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, @@ -1724,6 +1724,8 @@ def update_hessian(input, hessian, n_samples): n_samples: Number of samples already processed Returns: Tuple of (updated_hessian, new_sample_count) + + Note: input must be non-empty (batch_size > 0); a zero-sized input causes division by zero. """ # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens input_flat = input.reshape(-1, input.shape[-1]).t().float() @@ -1740,19 +1742,6 @@ def update_hessian(input, hessian, n_samples): return hessian, n_samples -def _get_quantized_linear_layers(parent: nn.Module) -> list[tuple[str, nn.Module]]: - """Return (name, module) pairs for all quantized linear layers with enabled weight quantizers. - - Also sets ``module.name`` on each returned module for downstream logging. - """ - layers = [] - for name, module in parent.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - module.name = name - layers.append((name, module)) - return layers - - @torch.no_grad() def sequential_calibrate( model: nn.Module, @@ -1867,7 +1856,11 @@ def gptq( max_calibrate(model, forward_loop=forward_loop) _promote_nvfp4_static_quantizers(model) - quantized_layers = _get_quantized_linear_layers(model) + quantized_layers = [ + (n, m) + for n, m in model.named_modules() + if is_quantized_linear(m) and m.weight_quantizer.is_enabled + ] if not quantized_layers: print_rank_0("No quantized linear layers found, skipping GPTQ") return @@ -1884,7 +1877,7 @@ def gptq( print_rank_0("Updating weights using GPTQ algorithm...") for handle in gptq_handles.values(): - handle.quantize(block_size, percdamp) + handle.update_weights(block_size, percdamp) del gptq_handles if torch.cuda.is_available(): diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 746e391f3f..21a1bd1658 100644 --- a/modelopt/torch/quantization/model_quant.py +++ 
b/modelopt/torch/quantization/model_quant.py @@ -565,7 +565,7 @@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): config = mtq.get_auto_quantize_config(search_state) # [Optional] Customize algorithm if needed - config["algorithm"] = {"method": "gptq_lite", "sequential": True} + config["algorithm"] = {"method": "gptq", "sequential": True} # Reuse on the same model (e.g. run a longer calibration pass) model = mtq.quantize(model, config, forward_loop=calibrate_loop) diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 639d376acb..a32d52e1f5 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -276,7 +276,7 @@ def test_gptq_e2e_flow(quant_cfg): model.eval() quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["algorithm"] = "gptq_lite" + quant_cfg["algorithm"] = {"method": "gptq", "use_sequential": True} # Define quantizer/dataloader calib_dataloader = get_dataset_dataloader( dataset_name="cnn_dailymail", From e2eb25ab17265f6c9df120ff15458a14171bebaa Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:24:09 +0000 Subject: [PATCH 37/52] refactor Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 296 +++++++++++---------- modelopt/torch/quantization/model_quant.py | 3 - tests/gpu/torch/quantization/test_gptq.py | 42 +-- 3 files changed, 157 insertions(+), 184 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index e41ee53005..28d697ddee 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1570,151 +1570,6 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -class GPTQHandle: - """Encapsulates per-module GPTQ state and operations. 
- - Owns the Hessian, patches the forward during collection, and contains - the blockwise weight-update logic. - - Instance attributes set during ``__init__``: - module, name, hessian, n_samples - - Instance attributes set during ``quantize``: - weight: float working copy of module weights (mutated in-place by update methods) - h_inv: upper-triangular Cholesky factor of the damped inverse Hessian - """ - - CACHE_NAME = "_forward_no_gptq_hessian" - - def __init__(self, module, name, offload_to_cpu=False): - self.module = module - self.name = name - in_features = module.weight.shape[-1] - device = module.weight.device - if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: - device = "cpu" - self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) - self.n_samples = 0 - # Set by quantize(); listed here for documentation. - self.weight: torch.Tensor | None = None - self.h_inv: torch.Tensor | None = None - - def setup(self): - """Patch the module's forward to accumulate Hessian during the collection pass.""" - gptq_handle = self - - def hessian_forward(self, input, *args, **kwargs): - inp = input.to_local() if hasattr(input, "to_local") else input - if self.input_quantizer is not None and self.input_quantizer.is_enabled: - hessian_input = self.input_quantizer(inp) - else: - hessian_input = inp - gptq_handle.hessian, gptq_handle.n_samples = update_hessian( - hessian_input, gptq_handle.hessian, gptq_handle.n_samples - ) - - self.weight_quantizer.disable() - out = self._forward_no_gptq_hessian(input, *args, **kwargs) - self.weight_quantizer.enable() - return out - - bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) - - def cleanup(self): - """Unpatch the module's forward method.""" - unpatch_forward_method(self.module, self.CACHE_NAME) - - def update_weights(self, block_size, percdamp): - """Run GPTQ blockwise weight update on this module. 
- - Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, - logs MSE, and writes the result back to the module. - """ - hessian = self.hessian.to(self.module.weight.device) - self.weight = self.module.weight.data.float().clone() - self._prepare_hessian_inverse(hessian, percdamp) - - self._blockwise_update(block_size) - - self._print_mse_error(hessian) - self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( - self.module.weight.data.dtype - ) - - # ------------------------------------------------------------------ - # Quantize helpers — all read from self.module, self.weight, self.h_inv - # ------------------------------------------------------------------ - - def _prepare_hessian_inverse(self, hessian, percdamp): - """Compute damped inverse Hessian and store as ``self.h_inv``. - - Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the - Hessian before inversion, matching the FP-Quant reference: - https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 - """ - assert self.weight is not None, "_prepare_hessian_inverse called before quantize()" - h = hessian.clone() - zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) - - h[zero_cols, :] = 0 - h[:, zero_cols] = 0 - h[zero_cols, zero_cols] = 1 - - damp = percdamp * torch.mean(torch.diag(h)) - diag_indices = torch.arange(h.shape[0], device=h.device) - h[diag_indices, diag_indices] += damp - - try: - h = torch.cholesky_inverse(torch.linalg.cholesky(h)) - self.h_inv = torch.linalg.cholesky(h, upper=True) - except (RuntimeError, torch.linalg.LinAlgError): - print_rank_0("Warning: Hessian is not positive definite, using identity matrix") - self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) - - def _blockwise_update(self, block_size): - """Column-wise GPTQ update using full-matrix QDQ. 
- - For each column, quantizes the full weight matrix via the quantizer and - extracts the quantized column. This is the standard GPTQ approach. - - Reads/writes ``self.weight`` and ``self.h_inv`` in-place. - """ - assert self.weight is not None and self.h_inv is not None, ( - "_blockwise_update called before _prepare_hessian_inverse()" - ) - quantizer = self.module.weight_quantizer - num_cols = self.weight.shape[1] - - for block_start in range(0, num_cols, block_size): - block_end = min(block_start + block_size, num_cols) - n_cols_blk = block_end - block_start - h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] - - wblk = self.weight.clone() - errs = torch.zeros_like(wblk[:, block_start:block_end]) - - for i in range(n_cols_blk): - w_ci = wblk[:, block_start + i] - d = h_inv_cho_blk[i, i] - qdq = quantizer(wblk) - self.weight[:, block_start + i] = qdq[:, block_start + i] - err = (w_ci - qdq[:, block_start + i]) / d - wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) - errs[:, i] = err - - self.weight[:, block_end:].addmm_( - errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 - ) - - def _print_mse_error(self, hessian): - """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" - w_orig = self.module.weight.float() - delta = self.weight - w_orig - mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) - suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" - print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") - - def update_hessian(input, hessian, n_samples): """Update hessian matrix with new input samples using incremental formula. @@ -1851,6 +1706,155 @@ def gptq( percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. """ + + class GPTQHelper: + """Encapsulates per-module GPTQ state and operations. 
+ + Owns the Hessian, patches the forward during collection, and contains + the blockwise weight-update logic. + + Instance attributes set during ``__init__``: + module, name, hessian, n_samples + + Instance attributes set during ``update_weights``: + weight: float working copy of module weights (mutated in-place by update methods) + h_inv: upper-triangular Cholesky factor of the damped inverse Hessian + """ + + CACHE_NAME = "_forward_no_gptq_hessian" + + def __init__(self, module, name, offload_to_cpu=False): + self.module = module + self.name = name + in_features = module.weight.shape[-1] + device = module.weight.device + if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + device = "cpu" + self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) + self.n_samples = 0 + # Set by update_weights(); listed here for documentation. + self.weight: torch.Tensor | None = None + self.h_inv: torch.Tensor | None = None + + def setup(self): + """Patch the module's forward to accumulate Hessian during the collection pass.""" + gptq_helper = self + + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + if self.input_quantizer is not None and self.input_quantizer.is_enabled: + hessian_input = self.input_quantizer(inp) + else: + hessian_input = inp + gptq_helper.hessian, gptq_helper.n_samples = update_hessian( + hessian_input, gptq_helper.hessian, gptq_helper.n_samples + ) + + self.weight_quantizer.disable() + out = self._forward_no_gptq_hessian(input, *args, **kwargs) + self.weight_quantizer.enable() + return out + + bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) + + def cleanup(self): + """Unpatch the module's forward method.""" + unpatch_forward_method(self.module, self.CACHE_NAME) + + def update_weights(self, block_size, percdamp): + """Run GPTQ blockwise weight update on this module. 
+ + Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, + logs MSE, and writes the result back to the module. + """ + hessian = self.hessian.to(self.module.weight.device) + self.weight = self.module.weight.data.float().clone() + self._prepare_hessian_inverse(hessian, percdamp) + + self._blockwise_update(block_size) + + self._print_mse_error(hessian) + self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( + self.module.weight.data.dtype + ) + + # ------------------------------------------------------------------ + # Quantize helpers — all read from self.module, self.weight, self.h_inv + # ------------------------------------------------------------------ + + def _prepare_hessian_inverse(self, hessian, percdamp): + """Compute damped inverse Hessian and store as ``self.h_inv``. + + Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the + Hessian before inversion, matching the FP-Quant reference: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 + """ + assert self.weight is not None, ( + "_prepare_hessian_inverse called before update_weights()" + ) + h = hessian.clone() + zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) + + h[zero_cols, :] = 0 + h[:, zero_cols] = 0 + h[zero_cols, zero_cols] = 1 + + damp = percdamp * torch.mean(torch.diag(h)) + diag_indices = torch.arange(h.shape[0], device=h.device) + h[diag_indices, diag_indices] += damp + + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + self.h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + print_rank_0("Warning: Hessian is not positive definite, using identity matrix") + self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) + + def _blockwise_update(self, block_size): + """Column-wise GPTQ update using full-matrix QDQ. 
+ + For each column, quantizes the full weight matrix via the quantizer and + extracts the quantized column. This is the standard GPTQ approach. + + Reads/writes ``self.weight`` and ``self.h_inv`` in-place. + """ + assert self.weight is not None and self.h_inv is not None, ( + "_blockwise_update called before _prepare_hessian_inverse()" + ) + quantizer = self.module.weight_quantizer + num_cols = self.weight.shape[1] + + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + n_cols_blk = block_end - block_start + h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] + + wblk = self.weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + + for i in range(n_cols_blk): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = quantizer(wblk) + self.weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err + + self.weight[:, block_end:].addmm_( + errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 + ) + + def _print_mse_error(self, hessian): + """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" + w_orig = self.module.weight.float() + delta = self.weight - w_orig + mse = (delta).mm(hessian).mul(delta).mean() / ( + w_orig.mm(hessian).mul(w_orig).mean() + 1e-6 + ) + suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" + print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") + total_start = time.time() max_calibrate(model, forward_loop=forward_loop) @@ -1865,7 +1869,7 @@ def gptq( print_rank_0("No quantized linear layers found, skipping GPTQ") return - gptq_handles = {name: GPTQHandle(m, name, offload_to_cpu=True) for name, m in quantized_layers} + gptq_handles = {name: GPTQHelper(m, name, offload_to_cpu=True) for name, m in quantized_layers} for 
handle in gptq_handles.values(): handle.setup() diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 21a1bd1658..5e65f9cc1d 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -564,9 +564,6 @@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): # Or use the original result config = mtq.get_auto_quantize_config(search_state) - # [Optional] Customize algorithm if needed - config["algorithm"] = {"method": "gptq", "sequential": True} - # Reuse on the same model (e.g. run a longer calibration pass) model = mtq.quantize(model, config, forward_loop=calibrate_loop) diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index a32d52e1f5..45f17833fd 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -21,7 +21,7 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight -from modelopt.torch.quantization.model_calib import GPTQHandle, update_hessian +from modelopt.torch.quantization.model_calib import gptq, update_hessian from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader @@ -127,39 +127,20 @@ def test_update_hessian(): def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): model = torch.nn.Linear(dim, dim).to("cuda") model.weight.data = model_weight - model.name = "linear" original_weight = model_weight.clone() - input = torch.randn(2, 16, dim).to("cuda") - hessian = torch.zeros(dim, dim).to("cpu") - n_samples = 0 + input_tensor = torch.randn(2, 16, dim).to("cuda") quant_cfg = mtq.NVFP4_DEFAULT_CFG - mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input)) + mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input_tensor)) # 
Get qdq weight q_dq_weight = model.weight_quantizer(model.weight.data) - # Restore original weight + # Restore original weight before GPTQ model.weight.data = original_weight.clone() - hessian, n_samples = update_hessian(input, hessian, n_samples) - - # Verify n_samples counts total tokens (batch * seq_len) after flattening - expected_tokens = input.shape[0] * input.shape[1] # 2 * 16 = 32 - assert n_samples == expected_tokens, f"n_samples should be {expected_tokens}, got {n_samples}" - - # Perform another forward pass to update hessian matrix - input_2 = torch.randn(3, 16, dim).to("cuda") - hessian, n_samples = update_hessian(input_2, hessian, n_samples) - expected_tokens_2 = expected_tokens + input_2.shape[0] * input_2.shape[1] # 32 + 48 = 80 - assert n_samples == expected_tokens_2, ( - f"n_samples should be {expected_tokens_2}, got {n_samples}" - ) - - handle = GPTQHandle(model, "linear") - handle.hessian = hessian.to(input.device) - handle.n_samples = n_samples - handle.quantize(block_size, 0.1) + # Run GPTQ through the public API + gptq(model, forward_loop=lambda m: m(input_tensor), percdamp=0.1, block_size=block_size) if expect_weight_change: # Weight must change as GPTQ updates weights to adjust for quantization error assert not torch.allclose(model.weight.data, q_dq_weight), "Weight should not be equal" @@ -175,7 +156,6 @@ def test_gptq_export_roundtrip(): # Step 1: Create a simple linear model and quantize to install NVFP4 quantizers model = torch.nn.Linear(dim, dim).to("cuda") - model.name = "linear" original_weight = model.weight.data.clone() input_tensor = torch.randn(2, 16, dim).to("cuda") quant_cfg = mtq.NVFP4_DEFAULT_CFG @@ -186,15 +166,7 @@ def test_gptq_export_roundtrip(): model.weight.data = original_weight.clone() # Step 2: Perform GPTQ — compute Hessian and update weights - hessian = torch.zeros(dim, dim, dtype=torch.float32) - n_samples = 0 - hessian, n_samples = update_hessian(input_tensor, hessian, n_samples) - hessian = hessian.to("cuda") - - 
handle = GPTQHandle(model, "linear") - handle.hessian = hessian - handle.n_samples = n_samples - handle.quantize(block_size, percdamp=0.1) + gptq(model, forward_loop=lambda m: m(input_tensor), percdamp=0.1, block_size=block_size) # Save the QDQ reference from the quantizer applied to GPTQ'd weights gptq_weight_shape = model.weight.data.shape From 472ec34d1ee6df2d8f35c142bbd17d233813cb26 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:27:55 +0000 Subject: [PATCH 38/52] stray changes removed Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index f9ba9784ba..e2f283e726 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -109,7 +109,6 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, - "nvfp4_gptq": mtq.NVFP4_GPTQ_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, } @@ -1029,7 +1028,6 @@ def quantize_main( is_nemotron_vl_model, first_text_speech_dataset, ) - export_quantized( args, full_model, @@ -1188,7 +1186,6 @@ def parse_args() -> argparse.Namespace: default=False, action="store_true", ) - parser.add_argument( "--low_memory_mode", help=( From b1dcdfc71c14f357ce660109fb6de457e4e2f548 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 25 Mar 2026 17:50:43 +0000 Subject: [PATCH 39/52] Address PR comments Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 4 +- modelopt/torch/quantization/model_calib.py | 53 ++++++++----------- .../torch/quantization/utils/core_utils.py | 43 +++++++++++++++ 
tests/gpu/torch/quantization/test_gptq.py | 20 +------ 4 files changed, 69 insertions(+), 51 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index fb2e8096b4..33d198eef4 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1544,14 +1544,14 @@ class GPTQConfig(QuantizeAlgorithmConfig): """ method: Literal["gptq"] = ModeloptField("gptq") - percdamp: float | None = ModeloptField( + percdamp: float = ModeloptField( default=0.01, gt=0.0, le=1.0, title="Percentage damping factor.", description="The percentage of average Hessian diagonal used for damping.", ) - block_size: int | None = ModeloptField( + block_size: int = ModeloptField( default=128, title="Block size for GPTQ weight update.", description="""The block size for GPTQ weight update, which must be a multiple of the diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 28d697ddee..7c9f47dca1 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -39,12 +39,14 @@ from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( disable_calib, + disabled_weight_quantizers, enable_fake_quant, enable_quant, enable_weight_access_and_writeback, is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, + promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, weight_attr_names, @@ -1647,33 +1649,6 @@ def _layer_forward_loop(m, _inputs=layer_inputs): print_rank_0("Sequential calibration completed") -def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: - """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. 
- - After max calibration sets per-block amax values, NVFP4 static quantizers - need to be promoted so they use the two-level scaling path (global amax + - per-block amax) instead of the generic E4M3 path. - - Returns the number of quantizers converted. - """ - converted = 0 - for _name, module in list(model.named_modules()): - if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - if is_nvfp4_static: - initial_amax = module._amax.clone().detach() - global_amax = reduce_amax(initial_amax, axis=None) - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) - converted += 1 - return converted - - @torch.no_grad() def gptq( model: nn.Module, @@ -1750,9 +1725,8 @@ def hessian_forward(self, input, *args, **kwargs): hessian_input, gptq_helper.hessian, gptq_helper.n_samples ) - self.weight_quantizer.disable() out = self._forward_no_gptq_hessian(input, *args, **kwargs) - self.weight_quantizer.enable() + return out bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) @@ -1761,6 +1735,12 @@ def cleanup(self): """Unpatch the module's forward method.""" unpatch_forward_method(self.module, self.CACHE_NAME) + def free(self): + """Release Hessian and working tensors to reclaim memory.""" + self.hessian = None + self.weight = None + self.h_inv = None + def update_weights(self, block_size, percdamp): """Run GPTQ blockwise weight update on this module. 
@@ -1822,6 +1802,14 @@ def _blockwise_update(self, block_size): "_blockwise_update called before _prepare_hessian_inverse()" ) quantizer = self.module.weight_quantizer + block_sizes = getattr(quantizer, "block_sizes", None) + if block_sizes is not None: + group_size = block_sizes.get(-1) + if group_size is not None and block_size % group_size != 0: + raise ValueError( + f"GPTQ block_size ({block_size}) must be divisible by the quantizer" + f" group_size ({group_size})" + ) num_cols = self.weight.shape[1] for block_start in range(0, num_cols, block_size): @@ -1858,7 +1846,7 @@ def _print_mse_error(self, hessian): total_start = time.time() max_calibrate(model, forward_loop=forward_loop) - _promote_nvfp4_static_quantizers(model) + promote_nvfp4_static_quantizers(model) quantized_layers = [ (n, m) @@ -1874,7 +1862,9 @@ def _print_mse_error(self, hessian): handle.setup() print_rank_0(f"Computing Hessians for {len(gptq_handles)} linear layers...") - forward_loop(model) + + with disabled_weight_quantizers(model): + forward_loop(model) for handle in gptq_handles.values(): handle.cleanup() @@ -1882,6 +1872,7 @@ def _print_mse_error(self, hessian): print_rank_0("Updating weights using GPTQ algorithm...") for handle in gptq_handles.values(): handle.update_weights(block_size, percdamp) + handle.free() del gptq_handles if torch.cuda.is_available(): diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 5a2fe37ad5..7793baa22f 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -28,6 +28,7 @@ from torch.distributed.tensor import Replicate from modelopt.torch.quantization.config import QuantizerCfgEntry +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: @@ -137,6 +138,33 @@ def convert_quantization_axis_to_reduce_axis(input, axis): return 
reduce_axis +def promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. + + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. + + Returns the number of quantizers converted. + """ + converted = 0 + for _name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted + + @torch.no_grad() def reduce_amax(input, axis=None, keepdims=True, squeeze_scalar=True): """Compute the absolute maximum value of a tensor. @@ -716,6 +744,21 @@ def disable_calib(quantizer): quantizer._if_calib = original_if_calib +@contextmanager +def disabled_weight_quantizers(model: nn.Module): + """Disable weight quantizers during hessian collection.""" + disabled_modules = [] + for module in model.modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + module.weight_quantizer.disable() + disabled_modules.append(module) + try: + yield + finally: + for module in disabled_modules: + module.weight_quantizer.enable() + + @contextmanager def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True): """Context manager to update the FSDPParam list if an update is made to a submodule of an FSDPModule. 
diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 45f17833fd..f7ef02de1d 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -204,26 +204,10 @@ def test_gptq_export_roundtrip(): num_mismatched = (diff > 1e-3).sum().item() total_elements = diff.numel() - print("\n--- Diff Stats ---") - print(f" Max diff: {max_diff}") - print(f" Mean diff: {diff.mean().item()}") - print(f" Median diff: {diff.median().item()}") - print(f" Std diff: {diff.std().item()}") - print( - f" Mismatched (>1e-3): {num_mismatched}/{total_elements} " - f"({100 * num_mismatched / total_elements:.2f}%)" - ) - print( - f" Max diff at [{max_diff_row}, {max_diff_col}]: " - f"deq={deq_weight[max_diff_row, max_diff_col].item()}, " - f"qdq_ref={qdq_ref[max_diff_row, max_diff_col].item()}" - ) - assert torch.allclose(deq_weight, qdq_ref.to(torch.bfloat16), atol=1e-2), ( f"Dequantized weight does not match QDQ reference. 
" - f"Max diff: {max_diff} at [{max_diff_row}, {max_diff_col}] " - f"(deq={deq_weight[max_diff_row, max_diff_col].item()}, " - f"qdq_ref={qdq_ref[max_diff_row, max_diff_col].item()})" + f"Max diff: {max_diff} at [{max_diff_row}, {max_diff_col}], " + f"mismatched (>1e-3): {num_mismatched}/{total_elements}" ) From a95bb7776955e0f7b1ecd49f8182990fc552d9dc Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Wed, 25 Mar 2026 18:21:03 +0000 Subject: [PATCH 40/52] fixed circular import issue Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 30 +++++++++++++++++-- .../torch/quantization/utils/core_utils.py | 28 ----------------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7c9f47dca1..f1c1aef449 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -46,7 +46,6 @@ is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, - promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, weight_attr_names, @@ -1649,6 +1648,33 @@ def _layer_forward_loop(m, _inputs=layer_inputs): print_rank_0("Sequential calibration completed") +def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. + + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. + + Returns the number of quantizers converted. 
+ """ + converted = 0 + for _name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted + + @torch.no_grad() def gptq( model: nn.Module, @@ -1846,7 +1872,7 @@ def _print_mse_error(self, hessian): total_start = time.time() max_calibrate(model, forward_loop=forward_loop) - promote_nvfp4_static_quantizers(model) + _promote_nvfp4_static_quantizers(model) quantized_layers = [ (n, m) diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 7793baa22f..0590c285f8 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -28,7 +28,6 @@ from torch.distributed.tensor import Replicate from modelopt.torch.quantization.config import QuantizerCfgEntry -from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: @@ -138,33 +137,6 @@ def convert_quantization_axis_to_reduce_axis(input, axis): return reduce_axis -def promote_nvfp4_static_quantizers(model: nn.Module) -> int: - """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. - - After max calibration sets per-block amax values, NVFP4 static quantizers - need to be promoted so they use the two-level scaling path (global amax + - per-block amax) instead of the generic E4M3 path. - - Returns the number of quantizers converted. 
- """ - converted = 0 - for _name, module in list(model.named_modules()): - if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - if is_nvfp4_static: - initial_amax = module._amax.clone().detach() - global_amax = reduce_amax(initial_amax, axis=None) - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) - converted += 1 - return converted - - @torch.no_grad() def reduce_amax(input, axis=None, keepdims=True, squeeze_scalar=True): """Compute the absolute maximum value of a tensor. From b64b7d3752e8816d94c7aeb93244e32ca16384e7 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 6 Apr 2026 18:25:32 +0000 Subject: [PATCH 41/52] refactor for MIT license Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 191 +------------- .../torch/quantization/utils/calib_utils.py | 248 ++++++++++++++++++ tests/gpu/torch/quantization/test_gptq.py | 3 +- 3 files changed, 251 insertions(+), 191 deletions(-) create mode 100644 modelopt/torch/quantization/utils/calib_utils.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index f1c1aef449..f016f0b6ae 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -32,7 +32,6 @@ from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method -from modelopt.torch.utils.perf import get_used_gpu_mem_fraction from .calib import MseCalibrator, NVFP4MSECalibrator from 
.conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context @@ -50,6 +49,7 @@ reduce_amax, weight_attr_names, ) +from .utils.calib_utils import GPTQHelper __all__ = [ "awq", @@ -1571,33 +1571,6 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -def update_hessian(input, hessian, n_samples): - """Update hessian matrix with new input samples using incremental formula. - - Args: - input: Input tensor (batch_size, ..., features) - hessian: Current Hessian matrix to update in-place - n_samples: Number of samples already processed - Returns: - Tuple of (updated_hessian, new_sample_count) - - Note: input must be non-empty (batch_size > 0); a zero-sized input causes division by zero. - """ - # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens - input_flat = input.reshape(-1, input.shape[-1]).t().float() - batch_size = input_flat.shape[1] - - # Incremental averaging: scale down old hessian - hessian *= n_samples / (n_samples + batch_size) - n_samples += batch_size - - # Compute outer product: H += (2/n_samples) * X @ X^T - scaled_input = math.sqrt(2 / n_samples) * input_flat - hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) - - return hessian, n_samples - - @torch.no_grad() def sequential_calibrate( model: nn.Module, @@ -1707,168 +1680,6 @@ def gptq( percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. """ - - class GPTQHelper: - """Encapsulates per-module GPTQ state and operations. - - Owns the Hessian, patches the forward during collection, and contains - the blockwise weight-update logic. 
- - Instance attributes set during ``__init__``: - module, name, hessian, n_samples - - Instance attributes set during ``update_weights``: - weight: float working copy of module weights (mutated in-place by update methods) - h_inv: upper-triangular Cholesky factor of the damped inverse Hessian - """ - - CACHE_NAME = "_forward_no_gptq_hessian" - - def __init__(self, module, name, offload_to_cpu=False): - self.module = module - self.name = name - in_features = module.weight.shape[-1] - device = module.weight.device - if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: - device = "cpu" - self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) - self.n_samples = 0 - # Set by update_weights(); listed here for documentation. - self.weight: torch.Tensor | None = None - self.h_inv: torch.Tensor | None = None - - def setup(self): - """Patch the module's forward to accumulate Hessian during the collection pass.""" - gptq_helper = self - - def hessian_forward(self, input, *args, **kwargs): - inp = input.to_local() if hasattr(input, "to_local") else input - if self.input_quantizer is not None and self.input_quantizer.is_enabled: - hessian_input = self.input_quantizer(inp) - else: - hessian_input = inp - gptq_helper.hessian, gptq_helper.n_samples = update_hessian( - hessian_input, gptq_helper.hessian, gptq_helper.n_samples - ) - - out = self._forward_no_gptq_hessian(input, *args, **kwargs) - - return out - - bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) - - def cleanup(self): - """Unpatch the module's forward method.""" - unpatch_forward_method(self.module, self.CACHE_NAME) - - def free(self): - """Release Hessian and working tensors to reclaim memory.""" - self.hessian = None - self.weight = None - self.h_inv = None - - def update_weights(self, block_size, percdamp): - """Run GPTQ blockwise weight update on this module. 
- - Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, - logs MSE, and writes the result back to the module. - """ - hessian = self.hessian.to(self.module.weight.device) - self.weight = self.module.weight.data.float().clone() - self._prepare_hessian_inverse(hessian, percdamp) - - self._blockwise_update(block_size) - - self._print_mse_error(hessian) - self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( - self.module.weight.data.dtype - ) - - # ------------------------------------------------------------------ - # Quantize helpers — all read from self.module, self.weight, self.h_inv - # ------------------------------------------------------------------ - - def _prepare_hessian_inverse(self, hessian, percdamp): - """Compute damped inverse Hessian and store as ``self.h_inv``. - - Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the - Hessian before inversion, matching the FP-Quant reference: - https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 - """ - assert self.weight is not None, ( - "_prepare_hessian_inverse called before update_weights()" - ) - h = hessian.clone() - zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) - - h[zero_cols, :] = 0 - h[:, zero_cols] = 0 - h[zero_cols, zero_cols] = 1 - - damp = percdamp * torch.mean(torch.diag(h)) - diag_indices = torch.arange(h.shape[0], device=h.device) - h[diag_indices, diag_indices] += damp - - try: - h = torch.cholesky_inverse(torch.linalg.cholesky(h)) - self.h_inv = torch.linalg.cholesky(h, upper=True) - except (RuntimeError, torch.linalg.LinAlgError): - print_rank_0("Warning: Hessian is not positive definite, using identity matrix") - self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) - - def _blockwise_update(self, block_size): - """Column-wise GPTQ update using full-matrix QDQ. 
- - For each column, quantizes the full weight matrix via the quantizer and - extracts the quantized column. This is the standard GPTQ approach. - - Reads/writes ``self.weight`` and ``self.h_inv`` in-place. - """ - assert self.weight is not None and self.h_inv is not None, ( - "_blockwise_update called before _prepare_hessian_inverse()" - ) - quantizer = self.module.weight_quantizer - block_sizes = getattr(quantizer, "block_sizes", None) - if block_sizes is not None: - group_size = block_sizes.get(-1) - if group_size is not None and block_size % group_size != 0: - raise ValueError( - f"GPTQ block_size ({block_size}) must be divisible by the quantizer" - f" group_size ({group_size})" - ) - num_cols = self.weight.shape[1] - - for block_start in range(0, num_cols, block_size): - block_end = min(block_start + block_size, num_cols) - n_cols_blk = block_end - block_start - h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] - - wblk = self.weight.clone() - errs = torch.zeros_like(wblk[:, block_start:block_end]) - - for i in range(n_cols_blk): - w_ci = wblk[:, block_start + i] - d = h_inv_cho_blk[i, i] - qdq = quantizer(wblk) - self.weight[:, block_start + i] = qdq[:, block_start + i] - err = (w_ci - qdq[:, block_start + i]) / d - wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) - errs[:, i] = err - - self.weight[:, block_end:].addmm_( - errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 - ) - - def _print_mse_error(self, hessian): - """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" - w_orig = self.module.weight.float() - delta = self.weight - w_orig - mse = (delta).mm(hessian).mul(delta).mean() / ( - w_orig.mm(hessian).mul(w_orig).mean() + 1e-6 - ) - suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" - print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") - total_start = time.time() max_calibrate(model, 
forward_loop=forward_loop) diff --git a/modelopt/torch/quantization/utils/calib_utils.py b/modelopt/torch/quantization/utils/calib_utils.py new file mode 100644 index 0000000000..e35a624651 --- /dev/null +++ b/modelopt/torch/quantization/utils/calib_utils.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/IST-DASLab/FP-Quant/blob/d2e3092/src/quantization/gptq.py +# with minor modifications to the original forms to accommodate minor architectural differences +# to be reused in the Model-Optimizer pipeline. +# Copyright (c) Andrei Panferov +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPTQ helper and Hessian utilities for calibration.""" + +import math + +import torch + +from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method +from modelopt.torch.utils.perf import get_used_gpu_mem_fraction + + +def update_hessian(input, hessian, n_samples): + """Update hessian matrix with new input samples using incremental formula. + + Args: + input: Input tensor (batch_size, ..., features) + hessian: Current Hessian matrix to update in-place + n_samples: Number of samples already processed + Returns: + Tuple of (updated_hessian, new_sample_count) + + Note: input must be non-empty (batch_size > 0); a zero-sized input causes division by zero. 
+ """ + # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens + input_flat = input.reshape(-1, input.shape[-1]).t().float() + batch_size = input_flat.shape[1] + + # Incremental averaging: scale down old hessian + hessian *= n_samples / (n_samples + batch_size) + n_samples += batch_size + + # Compute outer product: H += (2/n_samples) * X @ X^T + scaled_input = math.sqrt(2 / n_samples) * input_flat + hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) + + return hessian, n_samples + + +class GPTQHelper: + """Encapsulates per-module GPTQ state and operations. + + Owns the Hessian, patches the forward during collection, and contains + the blockwise weight-update logic. + + Instance attributes set during ``__init__``: + module, name, hessian, n_samples + + Instance attributes set during ``update_weights``: + weight: float working copy of module weights (mutated in-place by update methods) + h_inv: upper-triangular Cholesky factor of the damped inverse Hessian + """ + + CACHE_NAME = "_forward_no_gptq_hessian" + + def __init__(self, module, name, offload_to_cpu=False): + """Initialize GPTQHelper with module state and Hessian storage.""" + self.module = module + self.name = name + in_features = module.weight.shape[-1] + device = module.weight.device + if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + device = "cpu" + self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) + self.n_samples = 0 + # Set by update_weights(); listed here for documentation. 
+ self.weight: torch.Tensor | None = None + self.h_inv: torch.Tensor | None = None + + def setup(self): + """Patch the module's forward to accumulate Hessian during the collection pass.""" + gptq_helper = self + + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + if self.input_quantizer is not None and self.input_quantizer.is_enabled: + hessian_input = self.input_quantizer(inp) + else: + hessian_input = inp + gptq_helper.hessian, gptq_helper.n_samples = update_hessian( + hessian_input, gptq_helper.hessian, gptq_helper.n_samples + ) + + out = self._forward_no_gptq_hessian(input, *args, **kwargs) + + return out + + bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) + + def cleanup(self): + """Unpatch the module's forward method.""" + unpatch_forward_method(self.module, self.CACHE_NAME) + + def free(self): + """Release Hessian and working tensors to reclaim memory.""" + self.hessian = None + self.weight = None + self.h_inv = None + + def update_weights(self, block_size, percdamp): + """Run GPTQ blockwise weight update on this module. + + Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, + logs MSE, and writes the result back to the module. + """ + hessian = self.hessian.to(self.module.weight.device) + self.weight = self.module.weight.data.float().clone() + self._prepare_hessian_inverse(hessian, percdamp) + + self._blockwise_update(block_size) + + self._print_mse_error(hessian) + self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( + self.module.weight.data.dtype + ) + + # ------------------------------------------------------------------ + # Quantize helpers — all read from self.module, self.weight, self.h_inv + # ------------------------------------------------------------------ + + def _prepare_hessian_inverse(self, hessian, percdamp): + """Compute damped inverse Hessian and store as ``self.h_inv``. 
+ + Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the + Hessian before inversion, matching the FP-Quant reference: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 + """ + assert self.weight is not None, "_prepare_hessian_inverse called before update_weights()" + h = hessian.clone() + zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) + + h[zero_cols, :] = 0 + h[:, zero_cols] = 0 + h[zero_cols, zero_cols] = 1 + + damp = percdamp * torch.mean(torch.diag(h)) + diag_indices = torch.arange(h.shape[0], device=h.device) + h[diag_indices, diag_indices] += damp + + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + self.h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + print_rank_0("Warning: Hessian is not positive definite, using identity matrix") + self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) + + def _blockwise_update(self, block_size): + """Column-wise GPTQ update using full-matrix QDQ. + + For each column, quantizes the full weight matrix via the quantizer and + extracts the quantized column. This is the standard GPTQ approach. + + Reads/writes ``self.weight`` and ``self.h_inv`` in-place. 
+ """ + assert self.weight is not None and self.h_inv is not None, ( + "_blockwise_update called before _prepare_hessian_inverse()" + ) + quantizer = self.module.weight_quantizer + block_sizes = getattr(quantizer, "block_sizes", None) + if block_sizes is not None: + group_size = block_sizes.get(-1) + if group_size is not None and block_size % group_size != 0: + raise ValueError( + f"GPTQ block_size ({block_size}) must be divisible by the quantizer" + f" group_size ({group_size})" + ) + num_cols = self.weight.shape[1] + + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + n_cols_blk = block_end - block_start + h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] + + wblk = self.weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + + for i in range(n_cols_blk): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = quantizer(wblk) + self.weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err + + self.weight[:, block_end:].addmm_( + errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 + ) + + def _print_mse_error(self, hessian): + """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" + w_orig = self.module.weight.float() + delta = self.weight - w_orig + mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) + suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" + print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index f7ef02de1d..021ebd4049 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -21,8 +21,9 @@ import modelopt.torch.quantization as 
mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight -from modelopt.torch.quantization.model_calib import gptq, update_hessian +from modelopt.torch.quantization.model_calib import gptq from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor +from modelopt.torch.quantization.utils.calib_utils import update_hessian from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader RAND_SEED = 42 From 7de3cc0b0469ec2c4b51b0a217685bdb7dc82b12 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:35:57 +0000 Subject: [PATCH 42/52] PR review addressed Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- .pre-commit-config.yaml | 1 + modelopt/torch/quantization/model_calib.py | 8 ++--- .../torch/quantization/utils/calib_utils.py | 15 -------- .../torch/quantization/utils/core_utils.py | 15 -------- tests/gpu/torch/quantization/test_gptq.py | 35 ++++--------------- 5 files changed, 12 insertions(+), 62 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c21627234..dc464616f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -75,6 +75,7 @@ repos: # Instead, we should manually add the license header to those files *after* the original header. 
exclude: > (?x)^( + modelopt/torch/quantization/utils/calib_utils.py| modelopt/onnx/quantization/operators.py| modelopt/onnx/quantization/ort_patching.py| modelopt/torch/_deploy/utils/onnx_utils.py| diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index f016f0b6ae..ed329dc24a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -38,7 +38,6 @@ from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( disable_calib, - disabled_weight_quantizers, enable_fake_quant, enable_quant, enable_weight_access_and_writeback, @@ -1652,7 +1651,7 @@ def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: def gptq( model: nn.Module, forward_loop: ForwardLoop, - percdamp: float = 0.01, + perc_damp: float = 0.01, block_size: int = 128, ): """GPTQ quantization. @@ -1682,6 +1681,7 @@ def gptq( """ total_start = time.time() + # TODO: Add support for other scale setting strateiges like weight-mse or local-hessian max_calibrate(model, forward_loop=forward_loop) _promote_nvfp4_static_quantizers(model) @@ -1700,7 +1700,7 @@ def gptq( print_rank_0(f"Computing Hessians for {len(gptq_handles)} linear layers...") - with disabled_weight_quantizers(model): + with set_quantizer_by_cfg_context(model, {"*weight_quantizer": {"enable": False}}): forward_loop(model) for handle in gptq_handles.values(): @@ -1708,7 +1708,7 @@ def gptq( print_rank_0("Updating weights using GPTQ algorithm...") for handle in gptq_handles.values(): - handle.update_weights(block_size, percdamp) + handle.update_weights(block_size, perc_damp) handle.free() del gptq_handles diff --git a/modelopt/torch/quantization/utils/calib_utils.py b/modelopt/torch/quantization/utils/calib_utils.py index e35a624651..d5aacbcfb9 100644 --- a/modelopt/torch/quantization/utils/calib_utils.py +++ b/modelopt/torch/quantization/utils/calib_utils.py @@ -1,18 +1,3 @@ -# 
SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # Adapted from https://github.com/IST-DASLab/FP-Quant/blob/d2e3092/src/quantization/gptq.py # with minor modifications to the original forms to accommodate minor architectural differences # to be reused in the Model-Optimizer pipeline. diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 0590c285f8..5a2fe37ad5 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -716,21 +716,6 @@ def disable_calib(quantizer): quantizer._if_calib = original_if_calib -@contextmanager -def disabled_weight_quantizers(model: nn.Module): - """Disable weight quantizers during hessian collection.""" - disabled_modules = [] - for module in model.modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - module.weight_quantizer.disable() - disabled_modules.append(module) - try: - yield - finally: - for module in disabled_modules: - module.weight_quantizer.enable() - - @contextmanager def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True): """Context manager to update the FSDPParam list if an update is made to a submodule of an FSDPModule. 
diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 021ebd4049..8d5d0f3db0 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -17,7 +17,8 @@ import pytest import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from _test_utils.torch.transformers_models import get_tiny_llama +from transformers import AutoTokenizer import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_hf import _export_quantized_weight @@ -216,17 +217,12 @@ def test_gptq_export_roundtrip(): "quant_cfg", [mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG] ) def test_gptq_e2e_flow(quant_cfg): - model = AutoModelForCausalLM.from_pretrained( - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto" - ) - tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + model = get_tiny_llama(vocab_size=tokenizer.vocab_size).to("cuda") - # can't set attribute 'pad_token' for "" - # We skip this step for Nemo models - if tokenizer.pad_token != "" or tokenizer.pad_token is None: + if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - # Left padding usually provides better calibration result. tokenizer.padding_side = "left" assert tokenizer.pad_token is not None, "Pad token cannot be set!" @@ -234,31 +230,14 @@ def test_gptq_e2e_flow(quant_cfg): quant_cfg = copy.deepcopy(quant_cfg) quant_cfg["algorithm"] = {"method": "gptq", "use_sequential": True} - # Define quantizer/dataloader calib_dataloader = get_dataset_dataloader( dataset_name="cnn_dailymail", tokenizer=tokenizer, - batch_size=32, - num_samples=512, + batch_size=2, + num_samples=8, device="cuda", include_labels=False, ) - # Only run single sample for preview - prompt = "Where is New York city?" 
- input_ids = tokenizer(prompt, return_tensors="pt") - print(f"Input ids: {input_ids}") - generated_ids_before_ptq = model.generate( - input_ids["input_ids"].to("cuda"), max_new_tokens=100, do_sample=False, temperature=0.0 - ) - print( - f"Generated ids before quantization: {tokenizer.decode(generated_ids_before_ptq[0], skip_special_tokens=True)}" - ) calibrate_loop = create_forward_loop(dataloader=calib_dataloader) model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - generated_ids_after_ptq = model.generate( - input_ids["input_ids"].to("cuda"), max_new_tokens=100, do_sample=False, temperature=0.0 - ) - print( - f"Generated ids after quantization: {tokenizer.decode(generated_ids_after_ptq[0], skip_special_tokens=True)}" - ) From 748b6bf9d5123113ff93cc905a54263f3613e244 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Mon, 6 Apr 2026 22:12:57 +0000 Subject: [PATCH 43/52] moved promote_nvfp4_static_quantizers to utils Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 30 ++----------------- .../torch/quantization/utils/core_utils.py | 29 ++++++++++++++++++ 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ed329dc24a..e0a9031146 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -44,6 +44,7 @@ is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, + promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, weight_attr_names, @@ -1620,33 +1621,6 @@ def _layer_forward_loop(m, _inputs=layer_inputs): print_rank_0("Sequential calibration completed") -def _promote_nvfp4_static_quantizers(model: nn.Module) -> int: - """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. 
- - After max calibration sets per-block amax values, NVFP4 static quantizers - need to be promoted so they use the two-level scaling path (global amax + - per-block amax) instead of the generic E4M3 path. - - Returns the number of quantizers converted. - """ - converted = 0 - for _name, module in list(model.named_modules()): - if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - if is_nvfp4_static: - initial_amax = module._amax.clone().detach() - global_amax = reduce_amax(initial_amax, axis=None) - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) - converted += 1 - return converted - - @torch.no_grad() def gptq( model: nn.Module, @@ -1683,7 +1657,7 @@ def gptq( # TODO: Add support for other scale setting strateiges like weight-mse or local-hessian max_calibrate(model, forward_loop=forward_loop) - _promote_nvfp4_static_quantizers(model) + promote_nvfp4_static_quantizers(model) quantized_layers = [ (n, m) diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 5a2fe37ad5..da5a4d2d69 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -854,3 +854,32 @@ def update_quant_cfg_with_kv_cache_quant( quant_cfg["algorithm"] = "max" print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}") return quant_cfg + + +def promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. 
+ + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. + + Returns the number of quantizers converted. + """ + from ..nn import NVFP4StaticQuantizer, TensorQuantizer + + converted = 0 + for _name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted From 06ec8b0231f07296dfa5c47c00f36c26301f9557 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 04:59:06 +0000 Subject: [PATCH 44/52] fixed unit tests Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 2 +- modelopt/torch/quantization/model_calib.py | 2 +- .../torch/quantization/utils/calib_utils.py | 8 ++--- tests/gpu/torch/quantization/test_gptq.py | 29 +++++++------------ 4 files changed, 16 insertions(+), 25 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 33d198eef4..9ef36577b2 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1544,7 +1544,7 @@ class GPTQConfig(QuantizeAlgorithmConfig): """ method: Literal["gptq"] = ModeloptField("gptq") - percdamp: float = ModeloptField( + perc_damp: float = ModeloptField( default=0.01, gt=0.0, le=1.0, diff --git 
a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index e0a9031146..ec2b2ada07 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1650,7 +1650,7 @@ def gptq( model: The module to quantize — either the full model or a single decoder layer when invoked by ``sequential_calibrate``. forward_loop: Callable that replays calibration inputs through *model*. - percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). + perc_damp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. """ total_start = time.time() diff --git a/modelopt/torch/quantization/utils/calib_utils.py b/modelopt/torch/quantization/utils/calib_utils.py index d5aacbcfb9..b1d77677b7 100644 --- a/modelopt/torch/quantization/utils/calib_utils.py +++ b/modelopt/torch/quantization/utils/calib_utils.py @@ -134,7 +134,7 @@ def free(self): self.weight = None self.h_inv = None - def update_weights(self, block_size, percdamp): + def update_weights(self, block_size, perc_damp): """Run GPTQ blockwise weight update on this module. Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, @@ -142,7 +142,7 @@ def update_weights(self, block_size, percdamp): """ hessian = self.hessian.to(self.module.weight.device) self.weight = self.module.weight.data.float().clone() - self._prepare_hessian_inverse(hessian, percdamp) + self._prepare_hessian_inverse(hessian, perc_damp) self._blockwise_update(block_size) @@ -155,7 +155,7 @@ def update_weights(self, block_size, percdamp): # Quantize helpers — all read from self.module, self.weight, self.h_inv # ------------------------------------------------------------------ - def _prepare_hessian_inverse(self, hessian, percdamp): + def _prepare_hessian_inverse(self, hessian, perc_damp): """Compute damped inverse Hessian and store as ``self.h_inv``. 
Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the @@ -170,7 +170,7 @@ def _prepare_hessian_inverse(self, hessian, percdamp): h[:, zero_cols] = 0 h[zero_cols, zero_cols] = 1 - damp = percdamp * torch.mean(torch.diag(h)) + damp = perc_damp * torch.mean(torch.diag(h)) diag_indices = torch.arange(h.shape[0], device=h.device) h[diag_indices, diag_indices] += damp diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 8d5d0f3db0..8867854737 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -117,11 +117,11 @@ def test_update_hessian(): @pytest.mark.parametrize( ("block_size", "dim", "model_weight", "expect_weight_change"), [ - (4, 16, torch.randn(16, 16).to("cuda"), True), # random weight + (16, 128, torch.randn(128, 128).to("cuda"), True), # random weight ( - 4, 16, - torch.ones(16, 16).to("cuda"), + 128, + torch.ones(128, 128).to("cuda"), False, ), # all same weight -> no quantization error -> no GPTQ update ], @@ -142,7 +142,7 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): model.weight.data = original_weight.clone() # Run GPTQ through the public API - gptq(model, forward_loop=lambda m: m(input_tensor), percdamp=0.1, block_size=block_size) + gptq(model, forward_loop=lambda m: m(input_tensor), perc_damp=0.1, block_size=block_size) if expect_weight_change: # Weight must change as GPTQ updates weights to adjust for quantization error assert not torch.allclose(model.weight.data, q_dq_weight), "Weight should not be equal" @@ -154,12 +154,12 @@ def test_gptq_export_roundtrip(): """Test that GPTQ export + dequantize produces weights matching in-memory QDQ.""" torch.manual_seed(RAND_SEED) dim = 128 - block_size = 4 + block_size = 16 # Step 1: Create a simple linear model and quantize to install NVFP4 quantizers - model = torch.nn.Linear(dim, dim).to("cuda") + model = torch.nn.Linear(dim, dim, dtype=torch.bfloat16).to("cuda") 
original_weight = model.weight.data.clone() - input_tensor = torch.randn(2, 16, dim).to("cuda") + input_tensor = torch.randn(2, 16, dim, dtype=torch.bfloat16).to("cuda") quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(input_tensor)) @@ -168,7 +168,7 @@ def test_gptq_export_roundtrip(): model.weight.data = original_weight.clone() # Step 2: Perform GPTQ — compute Hessian and update weights - gptq(model, forward_loop=lambda m: m(input_tensor), percdamp=0.1, block_size=block_size) + gptq(model, forward_loop=lambda m: m(input_tensor), perc_damp=0.1, block_size=block_size) # Save the QDQ reference from the quantizer applied to GPTQ'd weights gptq_weight_shape = model.weight.data.shape @@ -198,18 +198,9 @@ def test_gptq_export_roundtrip(): assert deq_weight.shape == qdq_ref.shape, ( f"Shape mismatch: dequantized {deq_weight.shape} vs QDQ ref {qdq_ref.shape}" ) - diff = (deq_weight - qdq_ref.to(torch.bfloat16)).abs() - max_diff = diff.max().item() - max_diff_idx = diff.argmax().item() - max_diff_row = max_diff_idx // deq_weight.shape[1] - max_diff_col = max_diff_idx % deq_weight.shape[1] - num_mismatched = (diff > 1e-3).sum().item() - total_elements = diff.numel() - - assert torch.allclose(deq_weight, qdq_ref.to(torch.bfloat16), atol=1e-2), ( + assert torch.allclose(deq_weight, qdq_ref, atol=1e-2), ( f"Dequantized weight does not match QDQ reference. 
" - f"Max diff: {max_diff} at [{max_diff_row}, {max_diff_col}], " - f"mismatched (>1e-3): {num_mismatched}/{total_elements}" + f"Max diff: {(deq_weight - qdq_ref).abs().max().item()}" ) From 8feb42f277181d21d9c3862fc368c4731f5aa862 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 05:47:17 +0000 Subject: [PATCH 45/52] updated precision to bfloat16 Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/utils/calib_utils.py | 16 ++++++++-------- tests/gpu/torch/quantization/test_gptq.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/modelopt/torch/quantization/utils/calib_utils.py b/modelopt/torch/quantization/utils/calib_utils.py index b1d77677b7..6775ecbc44 100644 --- a/modelopt/torch/quantization/utils/calib_utils.py +++ b/modelopt/torch/quantization/utils/calib_utils.py @@ -60,7 +60,7 @@ def update_hessian(input, hessian, n_samples): Note: input must be non-empty (batch_size > 0); a zero-sized input causes division by zero. """ # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens - input_flat = input.reshape(-1, input.shape[-1]).t().float() + input_flat = input.reshape(-1, input.shape[-1]).t() batch_size = input_flat.shape[1] # Incremental averaging: scale down old hessian @@ -98,7 +98,9 @@ def __init__(self, module, name, offload_to_cpu=False): device = module.weight.device if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: device = "cpu" - self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) + self.hessian = torch.zeros( + in_features, in_features, dtype=module.weight.dtype, device=device + ) self.n_samples = 0 # Set by update_weights(); listed here for documentation. self.weight: torch.Tensor | None = None @@ -141,15 +143,13 @@ def update_weights(self, block_size, perc_damp): logs MSE, and writes the result back to the module. 
""" hessian = self.hessian.to(self.module.weight.device) - self.weight = self.module.weight.data.float().clone() + self.weight = self.module.weight.data.clone() self._prepare_hessian_inverse(hessian, perc_damp) self._blockwise_update(block_size) self._print_mse_error(hessian) - self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( - self.module.weight.data.dtype - ) + self.module.weight.data = self.weight.reshape(self.module.weight.shape) # ------------------------------------------------------------------ # Quantize helpers — all read from self.module, self.weight, self.h_inv @@ -226,8 +226,8 @@ def _blockwise_update(self, block_size): def _print_mse_error(self, hessian): """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" - w_orig = self.module.weight.float() + w_orig = self.module.weight.data delta = self.weight - w_orig - mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) + mse = delta.mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 8867854737..7bf487e365 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -198,7 +198,7 @@ def test_gptq_export_roundtrip(): assert deq_weight.shape == qdq_ref.shape, ( f"Shape mismatch: dequantized {deq_weight.shape} vs QDQ ref {qdq_ref.shape}" ) - assert torch.allclose(deq_weight, qdq_ref, atol=1e-2), ( + assert torch.equal(deq_weight, qdq_ref), ( f"Dequantized weight does not match QDQ reference. 
" f"Max diff: {(deq_weight - qdq_ref).abs().max().item()}" ) From 446c8a7eab871af181abcd199985194c35c8f33f Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:36:40 +0000 Subject: [PATCH 46/52] Revert "updated precision to bfloat16" This reverts commit 89c8db093c5fbb40cc7dbf17a766c98603bc5963. Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/utils/calib_utils.py | 16 ++++++++-------- tests/gpu/torch/quantization/test_gptq.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/modelopt/torch/quantization/utils/calib_utils.py b/modelopt/torch/quantization/utils/calib_utils.py index 6775ecbc44..b1d77677b7 100644 --- a/modelopt/torch/quantization/utils/calib_utils.py +++ b/modelopt/torch/quantization/utils/calib_utils.py @@ -60,7 +60,7 @@ def update_hessian(input, hessian, n_samples): Note: input must be non-empty (batch_size > 0); a zero-sized input causes division by zero. """ # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens - input_flat = input.reshape(-1, input.shape[-1]).t() + input_flat = input.reshape(-1, input.shape[-1]).t().float() batch_size = input_flat.shape[1] # Incremental averaging: scale down old hessian @@ -98,9 +98,7 @@ def __init__(self, module, name, offload_to_cpu=False): device = module.weight.device if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: device = "cpu" - self.hessian = torch.zeros( - in_features, in_features, dtype=module.weight.dtype, device=device - ) + self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) self.n_samples = 0 # Set by update_weights(); listed here for documentation. self.weight: torch.Tensor | None = None @@ -143,13 +141,15 @@ def update_weights(self, block_size, perc_damp): logs MSE, and writes the result back to the module. 
""" hessian = self.hessian.to(self.module.weight.device) - self.weight = self.module.weight.data.clone() + self.weight = self.module.weight.data.float().clone() self._prepare_hessian_inverse(hessian, perc_damp) self._blockwise_update(block_size) self._print_mse_error(hessian) - self.module.weight.data = self.weight.reshape(self.module.weight.shape) + self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( + self.module.weight.data.dtype + ) # ------------------------------------------------------------------ # Quantize helpers — all read from self.module, self.weight, self.h_inv @@ -226,8 +226,8 @@ def _blockwise_update(self, block_size): def _print_mse_error(self, hessian): """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" - w_orig = self.module.weight.data + w_orig = self.module.weight.float() delta = self.weight - w_orig - mse = delta.mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) + mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index 7bf487e365..8867854737 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -198,7 +198,7 @@ def test_gptq_export_roundtrip(): assert deq_weight.shape == qdq_ref.shape, ( f"Shape mismatch: dequantized {deq_weight.shape} vs QDQ ref {qdq_ref.shape}" ) - assert torch.equal(deq_weight, qdq_ref), ( + assert torch.allclose(deq_weight, qdq_ref, atol=1e-2), ( f"Dequantized weight does not match QDQ reference. 
" f"Max diff: {(deq_weight - qdq_ref).abs().max().item()}" ) From 1394a33a7b61f447b0ab42bdf87a4503c5976150 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:20:21 +0000 Subject: [PATCH 47/52] rebase fixes Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 40 ++----------------- modelopt/torch/quantization/model_calib.py | 4 +- .../torch/quantization/utils/core_utils.py | 3 +- 3 files changed, 8 insertions(+), 39 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 9ef36577b2..0ad5db2ab7 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1372,34 +1372,6 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig): description="If True, module's local Hessian metadata will be kept as a module attribute.", ) -class GPTQConfig(QuantizeAlgorithmConfig): - """The config for GPTQ quantization. - - GPTQ minimizes the layer-wise quantization error by using second-order (Hessian) information - to perform blockwise weight updates that compensate for rounding loss. Layers are quantized - sequentially so that each layer's Hessian is computed from activations that already reflect - the quantization of preceding layers. 
- - The default values are taken from the official GPTQ implementation: - https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 - """ - - method: Literal["gptq"] = ModeloptField("gptq") - percdamp: float | None = ModeloptField( - default=0.01, - gt=0.0, - le=1.0, - title="Percentage damping factor.", - description="The percentage of average Hessian diagonal used for damping.", - ) - block_size: int | None = ModeloptField( - default=128, - title="Block size for GPTQ weight update.", - description="""The block size for GPTQ weight update, which must be a multiple of the - group_size used in the quantization.""", - ) - - class SmoothQuantCalibConfig(QuantizeAlgorithmConfig): """The config for ``smoothquant`` algorithm (SmoothQuant). @@ -1544,26 +1516,22 @@ class GPTQConfig(QuantizeAlgorithmConfig): """ method: Literal["gptq"] = ModeloptField("gptq") - perc_damp: float = ModeloptField( + percdamp: float | None = ModeloptField( default=0.01, gt=0.0, le=1.0, title="Percentage damping factor.", description="The percentage of average Hessian diagonal used for damping.", ) - block_size: int = ModeloptField( + block_size: int | None = ModeloptField( default=128, title="Block size for GPTQ weight update.", description="""The block size for GPTQ weight update, which must be a multiple of the group_size used in the quantization.""", ) -QuantizeQuantCfgType = dict[ - str | Callable, - QuantizerAttributeConfig - | list[QuantizerAttributeConfig] - | dict[str | Callable, QuantizerAttributeConfig | list[QuantizerAttributeConfig]], -] + +QuantizeQuantCfgType = list[QuantizerCfgEntry] _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ec2b2ada07..7113bc1423 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1674,7 +1674,9 @@ def gptq( 
print_rank_0(f"Computing Hessians for {len(gptq_handles)} linear layers...") - with set_quantizer_by_cfg_context(model, {"*weight_quantizer": {"enable": False}}): + with set_quantizer_by_cfg_context( + model, [{"quantizer_name": "*weight_quantizer", "enable": False}] + ): forward_loop(model) for handle in gptq_handles.values(): diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index da5a4d2d69..6356b7fc61 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -28,6 +28,7 @@ from torch.distributed.tensor import Replicate from modelopt.torch.quantization.config import QuantizerCfgEntry +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: @@ -865,8 +866,6 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: Returns the number of quantizers converted. 
""" - from ..nn import NVFP4StaticQuantizer, TensorQuantizer - converted = 0 for _name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: From f0802142cbaecbfc8fb46263352531efa40546fe Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:25:12 +0000 Subject: [PATCH 48/52] minor rename Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 4 ++-- modelopt/torch/quantization/config.py | 2 +- modelopt/torch/quantization/mode.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index e2f283e726..bb240ba0cb 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -950,7 +950,7 @@ def quantize_main( else: # mono quantization - + if args.recipe is not None: print(f"Use recipe {args.recipe} for quantization") recipe = load_recipe(args.recipe) @@ -1251,7 +1251,7 @@ def parse_args() -> argparse.Namespace: help="Export as vLLM fake-quant checkpoint (produces vllm_fq_modelopt_state.pth " "for use with vllm_serve_fakequant.py).", ) - + args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 0ad5db2ab7..c7f5636b8d 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1503,7 +1503,7 @@ class SVDQuantConfig(QuantizeAlgorithmConfig): ) -class GPTQConfig(QuantizeAlgorithmConfig): +class GPTQCalibConfig(QuantizeAlgorithmConfig): """The config for GPTQ quantization. 
GPTQ minimizes the layer-wise quantization error by using second-order (Hessian) information diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 63b3a7c913..c81d5c89c7 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -37,7 +37,7 @@ AWQFullCalibConfig, AWQLiteCalibConfig, CompressConfig, - GPTQConfig, + GPTQCalibConfig, LocalHessianCalibConfig, MaxCalibConfig, MseCalibConfig, @@ -499,6 +499,6 @@ class GPTQModeDescriptor(BaseCalibrateModeDescriptor): @property def config_class(self) -> type[QuantizeAlgorithmConfig]: """Specifies the config class for the mode.""" - return GPTQConfig + return GPTQCalibConfig _calib_func = gptq From d64df14bd7d50c16154fa58c0b712f83620f8177 Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:43:22 +0000 Subject: [PATCH 49/52] fixed circular import Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/utils/core_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 6356b7fc61..da899c3419 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -28,7 +28,7 @@ from torch.distributed.tensor import Replicate from modelopt.torch.quantization.config import QuantizerCfgEntry -from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer +from modelopt.torch.quantization.nn import TensorQuantizer from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: @@ -866,6 +866,8 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: Returns the number of quantizers converted. 
""" + from modelopt.torch.quantization.nn import NVFP4StaticQuantizer + converted = 0 for _name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: From 3e75415553da7cacde6db7e8996ba6e97b3da1fd Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:50:30 +0000 Subject: [PATCH 50/52] fixed circular import Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/utils/core_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index da899c3419..6e7faf4189 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -28,7 +28,6 @@ from torch.distributed.tensor import Replicate from modelopt.torch.quantization.config import QuantizerCfgEntry -from modelopt.torch.quantization.nn import TensorQuantizer from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: @@ -866,7 +865,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: Returns the number of quantizers converted. 
""" - from modelopt.torch.quantization.nn import NVFP4StaticQuantizer + from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer converted = 0 for _name, module in list(model.named_modules()): From e734cee1614552ca71d3529cb41584eb7084e05b Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 20:48:16 +0000 Subject: [PATCH 51/52] minor fix Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index c7f5636b8d..f0fd61798b 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1516,7 +1516,7 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): """ method: Literal["gptq"] = ModeloptField("gptq") - percdamp: float | None = ModeloptField( + perc_damp: float | None = ModeloptField( default=0.01, gt=0.0, le=1.0, From 4073789338bc1c77e13ca60f96ad92329181d88a Mon Sep 17 00:00:00 2001 From: Suguna Velury <178320438+sugunav14@users.noreply.github.com> Date: Tue, 7 Apr 2026 20:56:56 +0000 Subject: [PATCH 52/52] rebase fix Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7113bc1423..35a0e931c9 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -135,26 +135,6 @@ def max_calibrate( if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_expert_weight_amax) - for name, module in list(model.named_modules()): - if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is 
not None and not module._dynamic and hasattr(module, "_amax"): - # Get the initial amax from max calibration - initial_amax = module._amax.clone().detach() - - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - - if is_nvfp4_static: - # Compute and set global_amax - global_amax = reduce_amax(initial_amax, axis=None) - - # Convert to NVFP4StaticQuantizer in-place - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) - if not distributed_sync: return