diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4116a8b6a9..24c50fb889 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -75,6 +75,7 @@ repos: # Instead, we should manually add the license header to those files *after* the original header. exclude: > (?x)^( + modelopt/torch/quantization/utils/calib_utils.py| modelopt/onnx/quantization/operators.py| modelopt/onnx/quantization/ort_patching.py| modelopt/torch/_deploy/utils/onnx_utils.py| diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 7bb1e2322d..f0fd61798b 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1503,24 +1503,20 @@ class SVDQuantConfig(QuantizeAlgorithmConfig): ) -class GPTQLiteConfig(QuantizeAlgorithmConfig): - """The config for GPTQ lite. +class GPTQCalibConfig(QuantizeAlgorithmConfig): + """The config for GPTQ quantization. - GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. - - GPTQ lite does not perform sequential quantization of layers. This means that the updated - activations are not used to process the next layer. + GPTQ minimizes the layer-wise quantization error by using second-order (Hessian) information + to perform blockwise weight updates that compensate for rounding loss. Layers are quantized + sequentially so that each layer's Hessian is computed from activations that already reflect + the quantization of preceding layers. The default values are taken from the official GPTQ implementation: https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 - - Note: This feature is currently experimental and may not translate to improved accuracy as expected. 
- - """ - method: Literal["gptq_lite"] = ModeloptField("gptq_lite") - percdamp: float | None = ModeloptField( + method: Literal["gptq"] = ModeloptField("gptq") + perc_damp: float | None = ModeloptField( default=0.01, gt=0.0, le=1.0, @@ -1533,12 +1529,6 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig): description="""The block size for GPTQ weight update, which must be a multiple of the group_size used in the quantization.""", ) - hessian_state_path: str | None = ModeloptField( - default=None, - title="Path to the Hessian state file.", - description="""The path to the Hessian state file. If hessian path exists, we load from - hessian file instead of recomputing them.""", - ) QuantizeQuantCfgType = list[QuantizerCfgEntry] diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index e08efece9a..c81d5c89c7 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -37,7 +37,7 @@ AWQFullCalibConfig, AWQLiteCalibConfig, CompressConfig, - GPTQLiteConfig, + GPTQCalibConfig, LocalHessianCalibConfig, MaxCalibConfig, MseCalibConfig, @@ -59,7 +59,7 @@ ) from .model_calib import ( awq, - gptq_lite, + gptq, local_hessian_calibrate, max_calibrate, mse_calibrate, @@ -240,8 +240,8 @@ def wrapped_calib_func( if sequential: if forward_loop is None: raise ValueError("forward_loop is required for calibration but got None.") - assert method in ["max"], ( - f"Sequential calibration currently only supports max calibration, got {method}" + assert method in ["max", "gptq"], ( + f"Sequential calibration currently only supports max and gptq calibration, got {method}" ) # Wrap with sequential processing sequential_calibrate( @@ -493,12 +493,12 @@ def restore(self) -> RestoreEntrypoint: @CalibrateModeRegistry.register_mode -class GPTQLiteModeDescriptor(BaseCalibrateModeDescriptor): +class GPTQModeDescriptor(BaseCalibrateModeDescriptor): """Mode for GPTQ calibration algorithm.""" @property def config_class(self) -> 
type[QuantizeAlgorithmConfig]: """Specifies the config class for the mode.""" - return GPTQLiteConfig + return GPTQCalibConfig - _calib_func = gptq_lite + _calib_func = gptq diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 1982fee716..35a0e931c9 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,7 +16,7 @@ """Calibration utilities.""" import math -import os +import time import warnings from collections.abc import Callable from functools import partial @@ -32,7 +32,6 @@ from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method -from modelopt.torch.utils.perf import get_used_gpu_mem_fraction from .calib import MseCalibrator, NVFP4MSECalibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context @@ -45,10 +44,12 @@ is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, + promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, weight_attr_names, ) +from .utils.calib_utils import GPTQHelper __all__ = [ "awq", @@ -1550,324 +1551,6 @@ def postprocess(module, name): max_calibrate(model, forward_loop) -def _print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, module_name: str): - """Print relative mean squared error between quantized and original weights. - - Computes the Hessian-weighted relative MSE between quantized and original weights, - providing a measure of quantization quality. This metric is adapted from the GPTQ - repository. 
- - Args: - q (torch.Tensor): Quantized weight tensor - w (torch.Tensor): Original weight tensor - h (torch.Tensor): Hessian matrix used for weighting the error - module_name (str): Name of the module for logging purposes - Note: - Implementation adapted from the GPTQ repository: - https://github.com/IST-DASLab/FP-Quant - """ - delta = q - w - mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) - print(f"[{module_name}] Relative MSE error: {mse.item():.2e}") - - -def update_hessian(input, hessian, n_samples): - """Update hessian matrix with new input samples using incremental formula. - - Args: - input: Input tensor (batch_size, ..., features) - hessian: Current Hessian matrix to update in-place - n_samples: Number of samples already processed - Returns: - Tuple of (updated_hessian, new_sample_count) - """ - batch_size = input.shape[0] - - # Incremental averaging: scale down old hessian - hessian *= n_samples / (n_samples + batch_size) - n_samples += batch_size - - # Compute outer product: H += (2/n_samples) * X @ X^T - # where X is the flattened input reshaped to (features, batch*seq) - input_flat = input.reshape(-1, input.shape[-1]).t().float() - scaled_input = math.sqrt(2 / n_samples) * input_flat - hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) - - return hessian, n_samples - - -def prepare_hessian_inverse(h, weight, percdamp): - """Prepare inverse Hessian with dead neuron handling and damping. 
- - Args: - h: Hessian matrix to update - weight: Weight tensor to prepare Hessian for - percdamp: Damping percentage for Hessian diagonal - Returns: - h_inv: Inverse Hessian matrix - Implementation adapted from the FP-Quant repository: - https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 - """ - h = h.clone() - # Handle dead neurons (zero weight columns) - # Get columns with all zeros in weight - zero_cols = torch.nonzero(weight.eq(0).all(dim=0)).unsqueeze(-1) - - # Zero out entire rows and columns in Hessian for dead neurons - h[zero_cols, :] = 0 - h[:, zero_cols] = 0 - h[zero_cols, zero_cols] = 1 - - # Add damping to diagonal - damp = percdamp * torch.mean(torch.diag(h)) - diag_indices = torch.arange(h.shape[0], device=h.device) - h[diag_indices, diag_indices] += damp - - try: - h = torch.cholesky_inverse(torch.linalg.cholesky(h)) - h_inv = torch.linalg.cholesky(h, upper=True) - except (RuntimeError, torch.linalg.LinAlgError): - print("Warning: Hessian is not positive definite, using identity matrix") - h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) - return h_inv - - -def quantize_block(full_weight, block_start, block_end, h_inv, quantizer): - """Quantize a block of weights group by group (based on quantizer block sizes) with error propagation. 
- - Args: - full_weight: The full weight tensor (needed for INT4 quantization) - block_start: Starting column index of the block - block_end: Ending column index of the block - h_inv: Hessian inverse - quantizer: The quantizer to apply - Returns: - quantized_block: Quantized weights for this block - losses: Quantization losses per element - errors: Accumulated errors for propagation - """ - # Extract the block we're working on - block_weight = full_weight[:, block_start:block_end] - block_hinv = h_inv[block_start:block_end, block_start:block_end] - block_size = block_end - block_start - - quantized_block = torch.zeros_like(block_weight) - losses = torch.zeros_like(block_weight) - errors = torch.zeros_like(block_weight) - - # We perform column-wise update for GPTQ within the block - group_size = 1 - - for group_start in range(0, block_size, group_size): - group_end = min(group_start + group_size, block_size) - group_cols = slice(group_start, group_end) - # Get current column and its Hessian inverse diagonal - weight_col = block_weight[:, group_cols] - hinv_diag = torch.diag(block_hinv[group_cols, group_cols]) - - # Quantize using the full weight, then extract the columns we need - quantized_full = quantizer(full_weight) - quantized_cols = quantized_full[:, block_start + group_start : block_start + group_end] - quantized_block[:, group_cols] = quantized_cols - - # Compute quantization error and loss - error = (weight_col - quantized_cols) / hinv_diag - losses[:, group_cols] = (weight_col - quantized_cols) ** 2 / (hinv_diag**2) / 2 - errors[:, group_cols] = error - - # Propagate error to remaining columns in block - block_weight[:, group_start:] -= error @ block_hinv[group_start:group_end, group_start:] - full_weight[:, block_start:block_end] = block_weight - - return quantized_block, losses, errors - - -def blockwise_weight_update(module, h, block_size, percdamp): - """Update module weights using GPTQ-style blockwise quantization. 
- - Args: - module: Neural network module with weight and weight_quantizer - H: Hessian matrix (d x d) - block_size: Size of blocks to process at once - percdamp: Damping percentage for Hessian diagonal - """ - weight = module.weight.data.float().clone() - _, num_cols = weight.shape - - # Preprocess Hessian: handle dead neurons and add damping - h_inv = prepare_hessian_inverse(h, weight, percdamp) - - # Initialize output tensors - quantized_weight = torch.zeros_like(weight) - losses = torch.zeros_like(weight) - - # Process weights in blocks - for block_start in range(0, num_cols, block_size): - block_end = min(block_start + block_size, num_cols) - - quantized_block, block_losses, block_errors = quantize_block( - weight, block_start, block_end, h_inv, module.weight_quantizer - ) - # Store results - quantized_weight[:, block_start:block_end] = quantized_block - losses[:, block_start:block_end] = block_losses - - # Propagate errors to remaining weights - weight[:, block_end:] -= block_errors @ h_inv[block_start:block_end, block_end:] - - # Print relative mse error - _print_relative_mse_error(quantized_weight, module.weight.float(), h, module.name) - # Update module weights - module.weight.data = quantized_weight.reshape(module.weight.shape).to(module.weight.data.dtype) - - -def gptq_lite( - model: nn.Module, - forward_loop: ForwardLoop | None = None, - percdamp: float = 0.01, - block_size: int = 128, - hessian_state_path: str | None = None, -): - """GPTQ-lite quantization - a simplified GPTQ variant. - - Key differences from GPTQ: - - Layers are quantized in parallel (not sequentially with updated activations) - - Uses group-wise updates instead of column-wise updates - - Args: - model: Model to be calibrated. - forward_loop: Callable that forwards calibration data through the model. - percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). - block_size: Block size for GPTQ weight update. - hessian_state_path: Path to save/load Hessian state. 
If None, compute without saving. - If path exists, load from it. If path doesnt exist then save computed hessians to path. - - See :class:`GPTQLiteConfig ` for - details on the remaining arguments. - - Note: This feature is currently experimental and may not translate to improved accuracy as expected. - """ - # Dictionary to store hessian matrices: {layer_name: {"hessian": Tensor, "n_samples": int}} - hessian_state = {} - - def initialize_hessian_state(tensor_mapping): - """Initialize hessian state with zeros.""" - for name, (shape, device) in tensor_mapping.items(): - # Use CPU if GPU memory is tight - target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device - hessian_state[name] = { - "hessian": torch.zeros(shape, dtype=torch.float32, device=target_device), - "n_samples": 0, - } - - def load_hessian_state(path, tensor_mapping): - """Load hessian state from file.""" - print_rank_0(f"Loading hessian state from {path}") - loaded_state = torch.load(path, map_location="cpu") - - for name, (shape, device) in tensor_mapping.items(): - if name not in loaded_state: - raise KeyError(f"Layer '{name}' not found in loaded hessian state") - - # Move to appropriate device based on memory - target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device - hessian_state[name] = { - "hessian": loaded_state[name]["hessian"].to(target_device), - "n_samples": loaded_state[name]["n_samples"], - } - - print_rank_0(f"Successfully loaded hessian state with {len(hessian_state)} layers") - - def save_hessian_state(path): - """Save hessian state to file.""" - print_rank_0(f"Saving hessian state to {path}") - try: - # Move to CPU for saving - cpu_state = { - name: {"hessian": state["hessian"].cpu(), "n_samples": state["n_samples"]} - for name, state in hessian_state.items() - } - - os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) - torch.save(cpu_state, path) - print_rank_0(f"Successfully saved hessian state to {path}") - 
except Exception as e: - print_rank_0(f"Error saving hessian state: {e}") - print_rank_0("Continuing execution...") - - def hessian_hook(module, input, output): - """Hook to intercept activations and update hessian matrix.""" - state = hessian_state[module.name] - hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) - hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} - - # Phase 1: Collect statistics for quantizers - max_calibrate(model) - - # Phase 2: Build tensor mapping for all quantized layers - tensor_mapping = {} - for name, module in model.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - in_features = module.weight.shape[-1] - tensor_mapping[name] = ((in_features, in_features), module.weight.device) - module.name = name # Attach name for easy access in hooks - - # Phase 3: Load or compute Hessians - hessian_exists = hessian_state_path is not None and os.path.exists(hessian_state_path) - save_hessians = hessian_state_path is not None and not hessian_exists - - if hessian_exists: - print_rank_0(f"Loading hessian state from {hessian_state_path}") - load_hessian_state(hessian_state_path, tensor_mapping) - else: - if forward_loop is None: - raise ValueError("forward_loop must be provided when computing Hessians") - - # Initialize hessian state - initialize_hessian_state(tensor_mapping) - - # Register hooks to collect activations - handles = [] - for name, module in model.named_modules(): - if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - handles.append(module.register_forward_hook(hessian_hook)) - - # Run forward loop to compute hessians - print_rank_0("Computing Hessian matrices...") - forward_loop(model) - - for handle in handles: - handle.remove() - - # Save if configured - if save_hessians: - try: - save_hessian_state(hessian_state_path) - except Exception as e: - print_rank_0(f"Error saving hessian state: {e}") - print_rank_0("Continuing 
execution...") - - # Phase 4: Update weights using computed Hessians - print_rank_0("Updating weights using GPTQ-lite algorithm...") - - quantized_modules = [ - (name, module) - for name, module in model.named_modules() - if is_quantized_linear(module) and module.weight_quantizer.is_enabled - ] - - # Perform blockwise weight updates - for name, module in tqdm(quantized_modules, desc="Quantizing layers"): - state = hessian_state[module.name] - hessian = state["hessian"].to(module.weight.device) - blockwise_weight_update(module, hessian, block_size, percdamp) - # Delete hessian state to free memory - del hessian_state[module.name] - torch.cuda.empty_cache() - - print_rank_0("GPTQ-lite quantization completed successfully") - - @torch.no_grad() def sequential_calibrate( model: nn.Module, @@ -1914,3 +1597,77 @@ def _layer_forward_loop(m, _inputs=layer_inputs): torch.cuda.empty_cache() finally: input_getter._unpatch_all_layers() + + print_rank_0("Sequential calibration completed") + + +@torch.no_grad() +def gptq( + model: nn.Module, + forward_loop: ForwardLoop, + perc_damp: float = 0.01, + block_size: int = 128, +): + """GPTQ quantization. + + Works in two modes depending on ``use_sequential`` in the config: + + * **Sequential** (``use_sequential=True``): ``sequential_calibrate`` calls this + function once per decoder layer with updated activations, producing more + accurate Hessian estimates. + * **Non-sequential** (``use_sequential=False``): called once on the full model. + All layers are quantized in parallel from the original activations. + + Per-module steps: + + 1. ``max_calibrate`` to set amax values from the current activations. + 2. Promote eligible quantizers to ``NVFP4StaticQuantizer`` (two-level scaling). + 3. Collect per-linear-layer Hessian matrices via forward hooks. + 4. Blockwise weight updates using the inverse Hessian to compensate for + rounding error (the core GPTQ column-wise update). 
+ Args: + model: The module to quantize — either the full model or a single decoder + layer when invoked by ``sequential_calibrate``. + forward_loop: Callable that replays calibration inputs through *model*. + perc_damp: Percentage of avg Hessian diagonal for damping (default: 0.01). + block_size: Block size for GPTQ weight update. + """ + total_start = time.time() + + # TODO: Add support for other scale setting strategies like weight-mse or local-hessian + max_calibrate(model, forward_loop=forward_loop) + promote_nvfp4_static_quantizers(model) + + quantized_layers = [ + (n, m) + for n, m in model.named_modules() + if is_quantized_linear(m) and m.weight_quantizer.is_enabled + ] + if not quantized_layers: + print_rank_0("No quantized linear layers found, skipping GPTQ") + return + + gptq_handles = {name: GPTQHelper(m, name, offload_to_cpu=True) for name, m in quantized_layers} + for handle in gptq_handles.values(): + handle.setup() + + print_rank_0(f"Computing Hessians for {len(gptq_handles)} linear layers...") + + with set_quantizer_by_cfg_context( + model, [{"quantizer_name": "*weight_quantizer", "enable": False}] + ): + forward_loop(model) + + for handle in gptq_handles.values(): + handle.cleanup() + + print_rank_0("Updating weights using GPTQ algorithm...") + for handle in gptq_handles.values(): + handle.update_weights(block_size, perc_damp) + handle.free() + del gptq_handles + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print_rank_0(f"GPTQ time: {time.time() - total_start:.2f}s") diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 746e391f3f..5e65f9cc1d 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -564,9 +564,6 @@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): # Or use the original result config = mtq.get_auto_quantize_config(search_state) - # [Optional] Customize algorithm if needed - 
config["algorithm"] = {"method": "gptq_lite", "sequential": True} - # Reuse on the same model (e.g. run a longer calibration pass) model = mtq.quantize(model, config, forward_loop=calibrate_loop) diff --git a/modelopt/torch/quantization/utils/calib_utils.py b/modelopt/torch/quantization/utils/calib_utils.py new file mode 100644 index 0000000000..b1d77677b7 --- /dev/null +++ b/modelopt/torch/quantization/utils/calib_utils.py @@ -0,0 +1,233 @@ +# Adapted from https://github.com/IST-DASLab/FP-Quant/blob/d2e3092/src/quantization/gptq.py +# with minor modifications to the original forms to accommodate minor architectural differences +# to be reused in the Model-Optimizer pipeline. +# Copyright (c) Andrei Panferov +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPTQ helper and Hessian utilities for calibration.""" + +import math + +import torch + +from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method +from modelopt.torch.utils.perf import get_used_gpu_mem_fraction + + +def update_hessian(input, hessian, n_samples): + """Update hessian matrix with new input samples using incremental formula. + + Args: + input: Input tensor (batch_size, ..., features) + hessian: Current Hessian matrix to update in-place + n_samples: Number of samples already processed + Returns: + Tuple of (updated_hessian, new_sample_count) + + Note: input must be non-empty (batch_size > 0); a zero-sized input causes division by zero. + """ + # Flatten to 2D (total_tokens, features) first, so batch_size counts tokens + input_flat = input.reshape(-1, input.shape[-1]).t().float() + batch_size = input_flat.shape[1] + + # Incremental averaging: scale down old hessian + hessian *= n_samples / (n_samples + batch_size) + n_samples += batch_size + + # Compute outer product: H += (2/n_samples) * X @ X^T + scaled_input = math.sqrt(2 / n_samples) * input_flat + hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) + + return hessian, n_samples + + +class GPTQHelper: + """Encapsulates per-module GPTQ state and operations. 
+ + Owns the Hessian, patches the forward during collection, and contains + the blockwise weight-update logic. + + Instance attributes set during ``__init__``: + module, name, hessian, n_samples + + Instance attributes set during ``update_weights``: + weight: float working copy of module weights (mutated in-place by update methods) + h_inv: upper-triangular Cholesky factor of the damped inverse Hessian + """ + + CACHE_NAME = "_forward_no_gptq_hessian" + + def __init__(self, module, name, offload_to_cpu=False): + """Initialize GPTQHelper with module state and Hessian storage.""" + self.module = module + self.name = name + in_features = module.weight.shape[-1] + device = module.weight.device + if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + device = "cpu" + self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) + self.n_samples = 0 + # Set by update_weights(); listed here for documentation. + self.weight: torch.Tensor | None = None + self.h_inv: torch.Tensor | None = None + + def setup(self): + """Patch the module's forward to accumulate Hessian during the collection pass.""" + gptq_helper = self + + def hessian_forward(self, input, *args, **kwargs): + inp = input.to_local() if hasattr(input, "to_local") else input + if self.input_quantizer is not None and self.input_quantizer.is_enabled: + hessian_input = self.input_quantizer(inp) + else: + hessian_input = inp + gptq_helper.hessian, gptq_helper.n_samples = update_hessian( + hessian_input, gptq_helper.hessian, gptq_helper.n_samples + ) + + out = self._forward_no_gptq_hessian(input, *args, **kwargs) + + return out + + bind_forward_method(self.module, hessian_forward, self.CACHE_NAME) + + def cleanup(self): + """Unpatch the module's forward method.""" + unpatch_forward_method(self.module, self.CACHE_NAME) + + def free(self): + """Release Hessian and working tensors to reclaim memory.""" + self.hessian = None + self.weight = None + self.h_inv = None + + def 
update_weights(self, block_size, perc_damp): + """Run GPTQ blockwise weight update on this module. + + Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, + logs MSE, and writes the result back to the module. + """ + hessian = self.hessian.to(self.module.weight.device) + self.weight = self.module.weight.data.float().clone() + self._prepare_hessian_inverse(hessian, perc_damp) + + self._blockwise_update(block_size) + + self._print_mse_error(hessian) + self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( + self.module.weight.data.dtype + ) + + # ------------------------------------------------------------------ + # Quantize helpers — all read from self.module, self.weight, self.h_inv + # ------------------------------------------------------------------ + + def _prepare_hessian_inverse(self, hessian, perc_damp): + """Compute damped inverse Hessian and store as ``self.h_inv``. + + Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the + Hessian before inversion, matching the FP-Quant reference: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 + """ + assert self.weight is not None, "_prepare_hessian_inverse called before update_weights()" + h = hessian.clone() + zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1) + + h[zero_cols, :] = 0 + h[:, zero_cols] = 0 + h[zero_cols, zero_cols] = 1 + + damp = perc_damp * torch.mean(torch.diag(h)) + diag_indices = torch.arange(h.shape[0], device=h.device) + h[diag_indices, diag_indices] += damp + + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + self.h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + print_rank_0("Warning: Hessian is not positive definite, using identity matrix") + self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) + + def _blockwise_update(self, block_size): + """Column-wise GPTQ update using 
full-matrix QDQ. + + For each column, quantizes the full weight matrix via the quantizer and + extracts the quantized column. This is the standard GPTQ approach. + + Reads/writes ``self.weight`` and ``self.h_inv`` in-place. + """ + assert self.weight is not None and self.h_inv is not None, ( + "_blockwise_update called before _prepare_hessian_inverse()" + ) + quantizer = self.module.weight_quantizer + block_sizes = getattr(quantizer, "block_sizes", None) + if block_sizes is not None: + group_size = block_sizes.get(-1) + if group_size is not None and block_size % group_size != 0: + raise ValueError( + f"GPTQ block_size ({block_size}) must be divisible by the quantizer" + f" group_size ({group_size})" + ) + num_cols = self.weight.shape[1] + + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + n_cols_blk = block_end - block_start + h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end] + + wblk = self.weight.clone() + errs = torch.zeros_like(wblk[:, block_start:block_end]) + + for i in range(n_cols_blk): + w_ci = wblk[:, block_start + i] + d = h_inv_cho_blk[i, i] + qdq = quantizer(wblk) + self.weight[:, block_start + i] = qdq[:, block_start + i] + err = (w_ci - qdq[:, block_start + i]) / d + wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1) + errs[:, i] = err + + self.weight[:, block_end:].addmm_( + errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 + ) + + def _print_mse_error(self, hessian): + """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" + w_orig = self.module.weight.float() + delta = self.weight - w_orig + mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6) + suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else "" + print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}") diff --git a/modelopt/torch/quantization/utils/core_utils.py 
b/modelopt/torch/quantization/utils/core_utils.py index 5a2fe37ad5..6e7faf4189 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -854,3 +854,32 @@ def update_quant_cfg_with_kv_cache_quant( quant_cfg["algorithm"] = "max" print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}") return quant_cfg + + +def promote_nvfp4_static_quantizers(model: nn.Module) -> int: + """Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place. + + After max calibration sets per-block amax values, NVFP4 static quantizers + need to be promoted so they use the two-level scaling path (global amax + + per-block amax) instead of the generic E4M3 path. + + Returns the number of quantizers converted. + """ + from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer + + converted = 0 + for _name, module in list(model.named_modules()): + if isinstance(module, TensorQuantizer) and not module._disabled: + if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) + if is_nvfp4_static: + initial_amax = module._amax.clone().detach() + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + converted += 1 + return converted diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index d43177cae2..8867854737 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -17,10 +17,14 @@ import pytest import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from _test_utils.torch.transformers_models import get_tiny_llama +from transformers import AutoTokenizer import modelopt.torch.quantization as mtq 
-from modelopt.torch.quantization.model_calib import blockwise_weight_update, update_hessian +from modelopt.torch.export.unified_export_hf import _export_quantized_weight +from modelopt.torch.quantization.model_calib import gptq +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor +from modelopt.torch.quantization.utils.calib_utils import update_hessian from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader RAND_SEED = 42 @@ -46,8 +50,11 @@ def test_update_hessian(): f"Expected hessian shape ({features}, {features}), got {updated_hessian.shape}" ) - # Verify sample count is updated correctly (incremented by batch_size) - assert new_n_samples == batch_size, f"Expected n_samples={batch_size}, got {new_n_samples}" + # Verify sample count is updated correctly (incremented by total tokens = batch * seq_len) + expected_n_samples = batch_size * seq_len + assert new_n_samples == expected_n_samples, ( + f"Expected n_samples={expected_n_samples}, got {new_n_samples}" + ) # Verify hessian is not all zeros after update assert not torch.allclose(updated_hessian, torch.zeros_like(updated_hessian)), ( @@ -70,22 +77,23 @@ def test_update_hessian(): # Manual calculation: # input_flat shape: (features, batch*seq) = (2, 12), all ones - # scaled_input = sqrt(2/6) * input_flat = sqrt(1/3) * ones(2, 12) - # outer_product = scaled_input @ scaled_input.t() = (2/6) * ones(2,12) @ ones(12,2) = [[4,4], [4,4]] - # Note: The scaling factor is (2/n_samples), so with n_samples=6 and 12 tokens: (2/6)*12 = 4 - expected_hessian = torch.ones(features, features, dtype=torch.float32) * 4.0 + # n_samples = batch * seq = 12 (token count after flattening) + # scaled_input = sqrt(2/12) * ones(2, 12) + # outer_product = (2/12) * ones(2,12) @ ones(12,2) = [[2,2], [2,2]] + expected_n_samples = batch_size * seq_len # 12 tokens + expected_hessian = torch.ones(features, features, dtype=torch.float32) * 2.0 assert torch.allclose(updated_hessian, 
expected_hessian, atol=1e-5), ( f"Expected hessian {expected_hessian}, got {updated_hessian}" ) - assert new_n_samples == batch_size + assert new_n_samples == expected_n_samples # Test 3: Accumulated hessians - verify equivalence # Processing [6,2,2] in one step should equal processing [2,2,2] three times seq_len = 2 features = 2 - # Process in 3 steps of batch_size=2 + # Process in 3 steps of batch_size=2 (4 tokens each, 12 total) hessian_accumulated = torch.zeros(features, features, dtype=torch.float32) n_samples_accumulated = 0 @@ -102,17 +110,18 @@ def test_update_hessian(): assert torch.allclose(hessian_accumulated, expected_hessian, atol=1e-5), ( f"Accumulated hessian should match expected: expected {expected_hessian}, got {hessian_accumulated}" ) - assert n_samples_accumulated == 6, f"Expected n_samples=6, got {n_samples_accumulated}" + # 3 batches * 2 batch_size * 2 seq_len = 12 tokens + assert n_samples_accumulated == 12, f"Expected n_samples=12, got {n_samples_accumulated}" @pytest.mark.parametrize( ("block_size", "dim", "model_weight", "expect_weight_change"), [ - (4, 16, torch.randn(16, 16).to("cuda"), True), # random weight + (16, 128, torch.randn(128, 128).to("cuda"), True), # random weight ( - 4, 16, - torch.ones(16, 16).to("cuda"), + 128, + torch.ones(128, 128).to("cuda"), False, ), # all same weight -> no quantization error -> no GPTQ update ], @@ -120,35 +129,20 @@ def test_update_hessian(): def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): model = torch.nn.Linear(dim, dim).to("cuda") model.weight.data = model_weight - model.name = "linear" original_weight = model_weight.clone() - input = torch.randn(2, 16, dim).to("cuda") - hessian = torch.zeros(dim, dim).to("cpu") - n_samples = 0 + input_tensor = torch.randn(2, 16, dim).to("cuda") quant_cfg = mtq.NVFP4_DEFAULT_CFG - mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input)) + mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input_tensor)) # 
Get qdq weight q_dq_weight = model.weight_quantizer(model.weight.data) - # Restore original weight + # Restore original weight before GPTQ model.weight.data = original_weight.clone() - hessian, n_samples = update_hessian(input, hessian, n_samples) - - # Verify n_samples is update using hessian matrix - assert n_samples == input.shape[0], "n_samples should be equal to input.shape[0]" - - # Perform another forward pass to update hessian matrix - input_2 = torch.randn(3, 16, dim).to("cuda") - hessian, n_samples = update_hessian(input_2, hessian, n_samples) - assert n_samples == input.shape[0] + input_2.shape[0], ( - "n_samples should be equal to input.shape[0] + input_2.shape[0]" - ) - - hessian = hessian.to(input.device) - blockwise_weight_update(model, hessian, block_size, 0.1) + # Run GPTQ through the public API + gptq(model, forward_loop=lambda m: m(input_tensor), perc_damp=0.1, block_size=block_size) if expect_weight_change: # Weight must change as GPTQ updates weights to adjust for quantization error assert not torch.allclose(model.weight.data, q_dq_weight), "Weight should not be equal" @@ -156,53 +150,85 @@ def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): assert torch.allclose(model.weight.data, q_dq_weight), "Weight should be equal" +def test_gptq_export_roundtrip(): + """Test that GPTQ export + dequantize produces weights matching in-memory QDQ.""" + torch.manual_seed(RAND_SEED) + dim = 128 + block_size = 16 + + # Step 1: Create a simple linear model and quantize to install NVFP4 quantizers + model = torch.nn.Linear(dim, dim, dtype=torch.bfloat16).to("cuda") + original_weight = model.weight.data.clone() + input_tensor = torch.randn(2, 16, dim, dtype=torch.bfloat16).to("cuda") + quant_cfg = mtq.NVFP4_DEFAULT_CFG + + mtq.quantize(model, quant_cfg, forward_loop=lambda m: m(input_tensor)) + + # Restore original weight before GPTQ + model.weight.data = original_weight.clone() + + # Step 2: Perform GPTQ — compute Hessian and update 
weights + gptq(model, forward_loop=lambda m: m(input_tensor), perc_damp=0.1, block_size=block_size) + + # Save the QDQ reference (GPTQ leaves the quantizer-applied QDQ values in model.weight) + gptq_weight_shape = model.weight.data.shape + gptq_weight_dtype = model.weight.data.dtype + qdq_ref = model.weight.data.clone() + + # Step 3: Export — converts weight to packed NVFP4 and registers scale buffers + _export_quantized_weight(model, torch.bfloat16) + + # Verify export produced the expected buffers + assert hasattr(model, "weight_scale"), "Export should register weight_scale buffer" + assert hasattr(model, "weight_scale_2"), "Export should register weight_scale_2 buffer" + + # Step 4: Dequantize the exported packed weight and compare with QDQ reference + packed_weight = model.weight.data + weight_scale = model.weight_scale + weight_scale_2 = model.weight_scale_2 + + nvfp4_qtensor = NVFP4QTensor(gptq_weight_shape, gptq_weight_dtype, packed_weight) + deq_weight = nvfp4_qtensor.dequantize( + dtype=torch.bfloat16, + scale=weight_scale, + double_scale=weight_scale_2, + block_sizes={-1: 16}, + ) + + assert deq_weight.shape == qdq_ref.shape, ( + f"Shape mismatch: dequantized {deq_weight.shape} vs QDQ ref {qdq_ref.shape}" + ) + assert torch.allclose(deq_weight, qdq_ref, atol=1e-2), ( + f"Dequantized weight does not match QDQ reference. 
" + f"Max diff: {(deq_weight - qdq_ref).abs().max().item()}" + ) + + @pytest.mark.parametrize( "quant_cfg", [mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG] ) def test_gptq_e2e_flow(quant_cfg): - model = AutoModelForCausalLM.from_pretrained( - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto" - ) - tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + model = get_tiny_llama(vocab_size=tokenizer.vocab_size).to("cuda") - # can't set attribute 'pad_token' for "" - # We skip this step for Nemo models - if tokenizer.pad_token != "" or tokenizer.pad_token is None: + if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - # Left padding usually provides better calibration result. tokenizer.padding_side = "left" assert tokenizer.pad_token is not None, "Pad token cannot be set!" model.eval() quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["algorithm"] = "gptq_lite" - # Define quantizer/dataloader + quant_cfg["algorithm"] = {"method": "gptq", "use_sequential": True} calib_dataloader = get_dataset_dataloader( dataset_name="cnn_dailymail", tokenizer=tokenizer, - batch_size=32, - num_samples=512, + batch_size=2, + num_samples=8, device="cuda", include_labels=False, ) - # Only run single sample for preview - prompt = "Where is New York city?" 
- input_ids = tokenizer(prompt, return_tensors="pt") - print(f"Input ids: {input_ids}") - generated_ids_before_ptq = model.generate( - input_ids["input_ids"].to("cuda"), max_new_tokens=100, do_sample=False, temperature=0.0 - ) - print( - f"Generated ids before quantization: {tokenizer.decode(generated_ids_before_ptq[0], skip_special_tokens=True)}" - ) calibrate_loop = create_forward_loop(dataloader=calib_dataloader) model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - generated_ids_after_ptq = model.generate( - input_ids["input_ids"].to("cuda"), max_new_tokens=100, do_sample=False, temperature=0.0 - ) - print( - f"Generated ids after quantization: {tokenizer.decode(generated_ids_after_ptq[0], skip_special_tokens=True)}" - )