
[PyTorch] Introduce quantizer roles #2620

Open
negvet wants to merge 27 commits into NVIDIA:main from negvet:semantic_quantizer_roles

Conversation

negvet (Collaborator) commented Jan 23, 2026

Description

Introducing QuantizerRole

@dataclasses.dataclass(frozen=True)
class QuantizerRole:
    module_type: str = ""   # e.g. "linear", "grouped_linear", "dpa"
    tensor_type: str = ""   # e.g. "input", "weight", "grad_output", "qkv", "s"
    name: str = ""          # instance name, e.g. "qkv", "proj", "fc1", "fc2"

This API allows control as fine-grained as "make this LayerNormLinear in this transformer layer less aggressively quantized" (a per-module/per-tensor quantization control mechanism).
See test_custom_recipe.py::test_custom_recipe_quantization_targets().
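
For illustration, a minimal standalone sketch of such a factory. The QuantizerRole class below mirrors the dataclass from this PR, and the recipe-name strings are stand-ins; a real qfactory would return transformer_engine quantizer objects.

```python
import dataclasses

# Standalone mirror of the QuantizerRole dataclass from this PR.
@dataclasses.dataclass(frozen=True)
class QuantizerRole:
    module_type: str = ""
    tensor_type: str = ""
    name: str = ""

# Hypothetical factory: quantize the "fc1" instance less aggressively.
# String labels stand in for real quantizer objects so the sketch runs
# without Transformer Engine installed.
def qfactory(role: QuantizerRole) -> str:
    if role.module_type == "linear" and role.name == "fc1":
        return "mxfp8"   # less aggressive choice for this one module
    return "nvfp4"       # default everywhere else

print(qfactory(QuantizerRole("linear", "weight", "fc1")))   # mxfp8
print(qfactory(QuantizerRole("linear", "weight", "proj")))  # nvfp4
```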

The quantizer factory uses roles to dispatch according to its needs.

Each TE module/op emits a list of QuantizerRole objects:

  • Linear, LayerNormLinear, LayerNormMLP emit module_type="linear" with tensor_type in {"input", "weight", "grad_output"}.
  • GroupedLinear emits module_type="grouped_linear".

CustomRecipe accepts a qfactory callable that receives QuantizerRole and returns a quantizer.

Factories can be composed: e.g., dispatch to different sub-factories based on module_type (dpa vs. linear), then refine based on tensor_type.
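
A minimal sketch of such composition, using a local QuantizerRole mirror and string labels in place of real quantizer objects:

```python
import dataclasses

# Standalone mirror of the QuantizerRole dataclass from this PR.
@dataclasses.dataclass(frozen=True)
class QuantizerRole:
    module_type: str = ""
    tensor_type: str = ""
    name: str = ""

# Sub-factories refine by tensor_type within one module family.
def linear_factory(role: QuantizerRole) -> str:
    return f"linear/{role.tensor_type or 'default'}"

def dpa_factory(role: QuantizerRole) -> str:
    return f"dpa/{role.tensor_type or 'default'}"

# Top-level factory: first dispatch on module_type, then delegate.
# Anything not explicitly mapped falls through to the linear branch.
def composed_qfactory(role: QuantizerRole) -> str:
    sub = {"dpa": dpa_factory}.get(role.module_type, linear_factory)
    return sub(role)

print(composed_qfactory(QuantizerRole("linear", "weight", "fc2")))  # linear/weight
print(composed_qfactory(QuantizerRole("dpa", "s", "qkv")))          # dpa/s
```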

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

negvet and others added 4 commits January 23, 2026 15:14
@negvet negvet requested review from cyanguwa and timmoon10 January 23, 2026 15:32
greptile-apps bot commented Jan 23, 2026

Greptile Summary

This PR introduces QuantizerRole, a frozen dataclass that enables fine-grained, per-module/per-tensor quantization control through a semantic role-based dispatch mechanism.

Key Changes:

  • Added QuantizerRole dataclass with fields: module_type (e.g., "linear", "grouped_linear"), tensor_type (e.g., "input", "weight", "grad_output"), and name (instance name)
  • All TE modules (Linear, GroupedLinear, LayerNormLinear, LayerNormMLP, DotProductAttention) now implement get_quantizer_roles() to emit semantic roles
  • CustomRecipe.qfactory now receives QuantizerRole objects instead of strings, enabling sophisticated dispatch logic
  • Added output_quantizer_role and grad_input_quantizer_role properties to base module for configuring consumer identity
  • Provided reference factory implementations in quantization_recipes_base.py that mirror built-in recipes
  • Added example factories demonstrating per-module quantization targeting (e.g., NVFP4 for Linear, MXFP8 for GroupedLinear)
  • Comprehensive test coverage including factory equivalence tests and fine-grained targeting validation

API Design:

  • Output and grad_input quantizer slots default to None (unknown consumer), allowing factories to provide fallback behavior
  • Modules can override these via setter properties when consumer identity is known (e.g., MHA sets DPA roles for QKV outputs)
  • Factories inspect only the role fields they care about, providing flexibility for different dispatch strategies
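
A sketch of the fallback pattern for unknown consumers (local QuantizerRole mirror; strings stand in for quantizer objects, and the specific role values are illustrative):

```python
import dataclasses
from typing import Optional

# Standalone mirror of the QuantizerRole dataclass from this PR.
@dataclasses.dataclass(frozen=True)
class QuantizerRole:
    module_type: str = ""
    tensor_type: str = ""
    name: str = ""

# A factory can treat a missing role (unknown consumer) as "leave the
# tensor in high precision" by returning None for that slot.
def qfactory(role: Optional[QuantizerRole]) -> Optional[str]:
    if role is None:
        return None           # unknown consumer: do not quantize
    if role.module_type == "dpa" and role.tensor_type == "input":
        return "fp8-for-dpa"  # consumer known to be attention
    return "default"
```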

Confidence Score: 5/5

  • This PR is safe to merge with comprehensive test coverage and well-architected API design
  • The implementation is thorough, with all modules properly updated to emit QuantizerRole objects. Test coverage is comprehensive, including factory equivalence tests that validate bit-identical results against built-in recipes. The API is clearly documented and marked as experimental. Previous review concerns about string parsing have been resolved.
  • No files require special attention

Important Files Changed

  • transformer_engine/pytorch/quantization.py: Introduces the QuantizerRole dataclass (module_type, tensor_type, name) and updates CustomRecipeState to accept a roles parameter
  • transformer_engine/pytorch/module/base.py: Adds output_quantizer_role and grad_input_quantizer_role properties to the base module; implements get_quantizer_roles()
  • transformer_engine/pytorch/module/linear.py: Implements get_quantizer_roles() returning Linear-specific roles (input, weight, grad_output) with module_type="linear"
  • transformer_engine/pytorch/ops/op.py: Adds a get_quantizer_roles() abstract method to BasicOperation; passes roles to RecipeState.create()
  • transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py: New file with reference factory implementations (current_scaling, mxfp8, float8_block_scaling, nvfp4) that mirror built-in recipes
  • tests/pytorch/test_custom_recipe.py: Updates tests to use QuantizerRole instead of strings; adds comprehensive factory equivalence tests
  • transformer_engine/common/recipe/__init__.py: Updates the CustomRecipe docstring to document the QuantizerRole parameter instead of string roles

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[TE Module/Op] -->|emits| B[QuantizerRole]
    B -->|module_type<br/>tensor_type<br/>name| C{CustomRecipe<br/>qfactory}
    C -->|dispatches based<br/>on role fields| D[Quantizer Instance]
    
    subgraph "QuantizerRole Fields"
    B1[module_type:<br/>linear, grouped_linear, dpa]
    B2[tensor_type:<br/>input, weight, grad_output]
    B3[name:<br/>qkv, proj, fc1, fc2]
    end
    
    subgraph "Module Examples"
    M1[Linear] -->|get_quantizer_roles| B
    M2[GroupedLinear] -->|get_quantizer_roles| B
    M3[LayerNormMLP] -->|get_quantizer_roles| B
    M4[DotProductAttention] -->|get_quantizer_roles| B
    end
    
    subgraph "Quantizer Factories"
    C -->|role-based dispatch| F1[NVFP4Quantizer]
    C -->|role-based dispatch| F2[MXFP8Quantizer]
    C -->|role-based dispatch| F3[Float8CurrentScalingQuantizer]
    C -->|role-based dispatch| F4[Float8BlockQuantizer]
    end
    
    style B fill:#e1f5ff
    style C fill:#fff4e6
    style D fill:#e8f5e9

Last reviewed commit: 343f653


timmoon10 (Collaborator) left a comment

Overall this design is quite clean and generalizable.

Comment on lines 1320 to 1329
    base = [
        QuantizerRole(module_type="linear", tensor_type="input", name=name),
        QuantizerRole(module_type="linear", tensor_type="weight", name=name),
        QuantizerRole(module_type="linear", tensor_type="output", name=name),
    ]
else:
    base = [
        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
        QuantizerRole(module_type="linear", tensor_type="grad_input", name=name),
    ]
timmoon10 (Collaborator), Feb 20, 2026

"output" and "grad_input" roles don't make sense. In reality, we are implicitly assuming that the tensor will be consumed by another linear-like layer.

Suggested change

-    base = [
-        QuantizerRole(module_type="linear", tensor_type="input", name=name),
-        QuantizerRole(module_type="linear", tensor_type="weight", name=name),
-        QuantizerRole(module_type="linear", tensor_type="output", name=name),
-    ]
-else:
-    base = [
-        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
-        QuantizerRole(module_type="linear", tensor_type="grad_input", name=name),
-    ]
+    base = [
+        QuantizerRole(module_type="linear", tensor_type="input", name=name),
+        QuantizerRole(module_type="linear", tensor_type="weight", name=name),
+        QuantizerRole(module_type="linear", tensor_type="input", name=name),
+    ]
+else:
+    base = [
+        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
+        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
+    ]

Alternatively, if we want to use the output in FP8 DPA, the right role would be module_type="dpa" and tensor_type="input". We should probably make this configurable. I kind of like that this design is exposing the hidden assumptions we've been making.

negvet (Author), Feb 25, 2026

I agree about the "output" and "grad_input" roles. I set the roles for those slots to None (the safest default) and made them configurable. Also configured this in MHA.

Comment on lines 310 to 314
assert counts["input"] == 1
assert counts["weight"] == 1
assert counts["output"] == 1
assert counts["grad_output"] == 1
assert counts["grad_input"] == 1

Suggested change

-assert counts["input"] == 1
-assert counts["weight"] == 1
-assert counts["output"] == 1
-assert counts["grad_output"] == 1
-assert counts["grad_input"] == 1
+assert counts["input"] == 2
+assert counts["weight"] == 1
+assert counts["output"] == 0
+assert counts["grad_output"] == 2
+assert counts["grad_input"] == 0

negvet and others added 2 commits February 20, 2026 14:31
negvet and others added 5 commits February 20, 2026 15:05
greptile-apps bot: 15 files reviewed, no comments

Comment on lines 85 to 88
def is_gemm(self) -> bool:
    """Whether this role belongs to a GEMM-based module."""
    return self.module_type in self.GEMM_MODULE_TYPES


I think this is baking in assumptions about what formats are similar (our recent experiences with grouped tensors makes me wonder if the requirements for "linear" and "grouped_linear" will diverge in the future), and it's also not giving us that much convenience.

Suggested change

-def is_gemm(self) -> bool:
-    """Whether this role belongs to a GEMM-based module."""
-    return self.module_type in self.GEMM_MODULE_TYPES

negvet (Author):

Sure, removed

greptile-apps bot: 17 files reviewed, no comments

@negvet negvet changed the title [PyTorch] Introduce semantic quantizer roles [PyTorch] Introduce quantizer roles Feb 25, 2026
Evgeny and others added 3 commits February 25, 2026 14:30
greptile-apps bot: 24 files reviewed, no comments


Evgeny and others added 2 commits February 25, 2026 16:43
@negvet negvet requested review from ptrendx and timmoon10 February 25, 2026 16:45
greptile-apps bot: 24 files reviewed, no comments

negvet and others added 2 commits February 26, 2026 11:17