
Add NVTE_BACKWARD_MODE=default|unquant|dequant #2644

Open

zianglih wants to merge 45 commits into NVIDIA:main from zianglih:keep-bwd
Conversation

@zianglih zianglih commented Feb 3, 2026

Description

Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.

Add NVTE_BACKWARD_MODE=default|unquant|dequant env var:

  • default: existing default quantization behavior
  • unquant: quantized fprop + high precision wgrad & dgrad using unquantized activation and weight
  • dequant: quantized fprop + high precision wgrad & dgrad using activation and weight dequantized directly from the fprop quantized values
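
The three modes above are selected through an environment variable. A minimal sketch of how such a setting might be resolved and validated — the helper name mirrors the PR's `_resolve_backward_mode`, but this stand-alone version is illustrative, not the actual Transformer Engine implementation:

```python
import os

# Illustrative helper; the real validation lives in transformer_engine's
# recipe module and may differ in detail.
_VALID_MODES = ("default", "unquant", "dequant")

def resolve_backward_mode(env=None):
    """Read NVTE_BACKWARD_MODE, falling back to "default" when unset."""
    if env is None:
        env = os.environ
    mode = env.get("NVTE_BACKWARD_MODE", "default").lower()
    if mode not in _VALID_MODES:
        raise ValueError(
            f"NVTE_BACKWARD_MODE must be one of {_VALID_MODES}, got {mode!r}"
        )
    return mode

# Example: request high-precision backward using the original operands.
mode = resolve_backward_mode({"NVTE_BACKWARD_MODE": "unquant"})
```

In practice the variable would be set before launching training, e.g. `NVTE_BACKWARD_MODE=unquant python train.py`.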

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot commented Feb 3, 2026

Greptile Summary

This PR successfully refactors the backward precision control interface from NVTE_KEEP_BACKWARD_UNQUANTIZED to NVTE_BACKWARD_MODE with three well-defined modes:

  • default: Standard quantized backward pass (existing behavior)
  • unquant: Saves original high-precision operands for backward computation
  • dequant: Saves quantized operands and dequantizes them during backward

Key Changes

  • All recipe classes now include a backward_mode field with proper validation through _resolve_backward_mode()
  • DelayedScaling recipe enforces backward_mode=default only (documented limitation)
  • LayerNormMLP module doesn't support unquant/dequant modes (clear error message provided)
  • Backward fusion is disabled for unquant/dequant modes to maintain high-precision computation
  • Special handling for MXFP8/NVFP4 in dequant mode: optimize_for_gemm is disabled
  • Edge case handling for empty M-dimension in grouped linear operations
  • Comprehensive test suite (1446 lines) with excellent coverage
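
The recipe-level validation described in the first bullet can be sketched as follows. Class names other than the `backward_mode` field and `_resolve_backward_mode` are stand-ins, not copied from the PR:

```python
from dataclasses import dataclass

# Stand-in sketch of recipe-level validation; the real classes live in
# transformer_engine/common/recipe and carry many more fields.
def _resolve_backward_mode(mode: str) -> str:
    if mode not in ("default", "unquant", "dequant"):
        raise ValueError(f"invalid backward_mode: {mode!r}")
    return mode

@dataclass
class Recipe:
    backward_mode: str = "default"

    def __post_init__(self):
        self.backward_mode = _resolve_backward_mode(self.backward_mode)

@dataclass
class DelayedScaling(Recipe):
    def __post_init__(self):
        super().__post_init__()
        # Per the summary, DelayedScaling only supports the standard
        # quantized backward pass.
        if self.backward_mode != "default":
            raise ValueError("DelayedScaling requires backward_mode='default'")
```

This mirrors the documented limitation: any recipe accepts the three modes, while `DelayedScaling` rejects everything but `default` at construction time.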

Implementation Quality

The implementation is thorough and well-structured:

  • All get_fp8_recipe() calls are properly guarded with is_fp8_enabled() checks
  • Tensor usage flags are correctly managed for each mode
  • Memory trade-offs are clear: unquant mode stores high-precision tensors, dequant mode stores quantized tensors
  • Operation fuser properly tracks backward_mode changes to trigger rebuilds

No critical issues found. The code is production-ready.

Confidence Score: 5/5

  • Safe to merge - well-implemented refactor with comprehensive test coverage
  • Excellent implementation quality with proper error handling, edge case coverage, thorough testing (1446 lines), and all known issues from previous reviews addressed. No critical issues found.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/common/recipe/__init__.py Adds backward_mode field to all recipe classes with proper validation via _resolve_backward_mode(). DelayedScaling correctly enforces backward_mode=default only.
tests/pytorch/test_backward_mode.py Comprehensive test suite (1446 lines) covering all backward modes across different recipes, modules, and edge cases. Tests validate correctness against reference implementations.
transformer_engine/pytorch/ops/basic/basic_linear.py Implements backward_mode logic: saves unquantized tensors for unquant, quantized for dequant. Disables optimize_for_gemm for MXFP8/NVFP4 in dequant mode. All get_fp8_recipe() calls properly guarded.
transformer_engine/pytorch/module/linear.py Updates Linear module to support backward_mode. Properly sets save_original_input=True for unquant mode and handles recipe-specific constraints for MXFP8/NVFP4.
transformer_engine/pytorch/module/layernorm_mlp.py Explicitly asserts that LayerNormMLP doesn't support unquant/dequant modes with clear error message directing users to use LayerNormLinear + Linear instead.
transformer_engine/pytorch/module/layernorm_linear.py Saves high-precision ln_out_hp for unquant mode. Properly manages tensor usage flags and disables fusion for unquant/dequant modes.
transformer_engine/pytorch/module/grouped_linear.py Handles dequant mode with special case for empty M-dimension splits. Properly dequantizes weights and inputs for backward pass in unquant/dequant modes.

Last reviewed commit: 0dee809

@greptile-apps bot left a comment

17 files reviewed, 3 comments

@greptile-apps bot left a comment

6 files reviewed, no comments


zianglih commented Feb 3, 2026

I'll work on potential unit test breakage.

@greptile-apps bot left further reviews:

  • 5 files reviewed, no comments
  • 5 files reviewed, 4 comments
  • 4 files reviewed, 1 comment
  • 5 files reviewed, 1 comment
  • 5 files reviewed, 2 comments
  • 4 files reviewed, no comments
  • 5 files reviewed, no comments
  • 5 files reviewed, no comments
  • 5 files reviewed, no comments

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
- if ctx.grad_output_quantizer is not None:
+ if ctx.grad_output_quantizer is not None and use_fp8_bwd:
A collaborator commented:

this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?

not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
A collaborator commented:

same comment as above

recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
# Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
return False
A collaborator commented:

Maybe it's better to assert an error for delayed scaling? Okay with both.

Another collaborator replied:

I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
- if ctx.grad_output_quantizer is not None:
+ if ctx.grad_output_quantizer is not None and use_fp8_bwd:
A collaborator commented:

this seems redundant too if we skip quant in grad_output_preprocess

zianglih and others added 19 commits February 24, 2026 15:28
Signed-off-by: Ziang Li <ziangli@umich.edu>
zianglih commented:

I have finished the refactor. All new unit tests passed on B200. No new failing tests in the entire pytorch test suite compared to main.

+ python3 -m pytest --tb=auto --junitxml=/sgl-workspace/logs/te_pr_full_pytorch_unittest_20260224_185629/xml/pytest_test_backward_mode.xml /root/TransformerEngine-pr/tests/pytorch/test_backward_mode.py
============================= test session starts ==============================
platform linux -- Python 3.12.3, pytest-8.2.1, pluggy-1.6.0
rootdir: /root/TransformerEngine-pr
configfile: pyproject.toml
plugins: hydra-core-1.3.2, anyio-4.12.1, typeguard-4.4.4
collected 1034 items


- generated xml file: /sgl-workspace/logs/te_pr_full_pytorch_unittest_20260224_185629/xml/pytest_test_backward_mode.xml -
======================= 632 passed, 402 skipped in 8.56s =======================

@zianglih zianglih marked this pull request as ready for review February 25, 2026 03:55
@greptile-apps bot left a comment

17 files reviewed, no comments

@greptile-apps bot left a comment

17 files reviewed, no comments

Comment on lines +1542 to +1546
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()):
input_quantizer.optimize_for_gemm = False
if grad_output_quantizer is not None:
grad_output_quantizer.optimize_for_gemm = False
greptile-apps bot commented:

missing comment explaining why optimize_for_gemm must be disabled for MXFP8/NVFP4 in dequant mode

add brief explanation of the constraint


@ziang-and ziang-and force-pushed the keep-bwd branch 2 times, most recently from 3d64956 to 295d03b on February 26, 2026 07:02
Signed-off-by: Ziang Li <ziangli@umich.edu>


5 participants