Skip to content

remove str option for quantization config in torchao#13291

Open
howardzhang-cv wants to merge 7 commits intohuggingface:mainfrom
howardzhang-cv:update_torchao_test
Open

remove str option for quantization config in torchao#13291
howardzhang-cv wants to merge 7 commits intohuggingface:mainfrom
howardzhang-cv:update_torchao_test

Conversation

@howardzhang-cv
Copy link
Contributor

What does this PR do?

Remove the deprecated string-based quant_type path from TorchAoConfig, requiring AOBaseConfig instances instead.

  • TorchAoConfig.init now only accepts AOBaseConfig subclass instances (e.g. Int8WeightOnlyConfig()) and raises TypeError for strings
  • Deleted ~200 lines of dead code: _get_torchao_quant_type_to_method, _is_xpu_or_cuda_capability_atleast_8_9, TorchAoJSONEncoder, and all string-parsing branches in post_init, to_dict, from_dict, get_apply_tensor_subclass
  • Simplified torchao_quantizer.py: removed string-based branches in update_torch_dtype, adjust_target_dtype, get_cuda_warm_up_factor; fixed is_trainable which would crash on AOBaseConfig objects
  • Converted all test cases from string quant types to their AOBaseConfig equivalents; removed test_floatx_quantization (no replacement for fpx_weight_only)
  • Updated docs to show only AOBaseConfig-based usage

Testing

python -m pytest tests/quantization/torchao/test_torchao.py -xvs

Who can review?

@sayakpaul

Comment on lines +217 to +227
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)

# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16

elif is_torchao_version(">", "0.9.0"):
from torchao.core.config import AOBaseConfig

quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)

# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
else:
# Default to int8
return torch.int8
else:
# Default to int8
return torch.int8

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems a bit fragile, I think it's from transformers originally, not sure if this is still needed in transformers though, it might have been refactored after 5.0 update

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry might be a dumb question, what part are you referring to that's from transformers originally? Is it the entire adjust_target_dtype function?

f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
if is_torchao_version("<=", "0.9.0"):
Copy link

@jerryzh168 jerryzh168 Mar 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

separate PR: I feel we should just have a single assertion for torchao to be a relatively recent version (e.g. 0.15) and remove all these version checks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually yeah I was going to ask you about that as well. There's a couple version checks scattered around right now. Would be cleaner to just remove all of them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we should mandate a minimum version requirement here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to wait to do this in a separate PR? I changed and set it to 0.9.0 because that's when AOBaseConfig was supported. Moving to 0.15.0 might be cleaner in a separate PR in case we need to revert for whatever reason?

@howardzhang-cv howardzhang-cv marked this pull request as ready for review March 19, 2026 21:43
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for starting this work!

I think we can merge this fairly soon.

f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
if is_torchao_version("<=", "0.9.0"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we should mandate a minimum version requirement here.

Comment on lines +478 to +479
if not isinstance(self.quant_type, AOBaseConfig):
raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes cool!

@sayakpaul
Copy link
Member

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Mar 20, 2026

Style bot fixed some files and pushed the changes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this. I will run the tests on my end to see if we didn't miss anything.

We need to update the test suite here as well:

class TorchAoConfigMixin:

logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we not implement an equivalent of

        if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
            if torch_dtype is not None and torch_dtype != torch.bfloat16:
                logger.warning(
                    f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
                    f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
                )

?

Copy link
Contributor Author

@howardzhang-cv howardzhang-cv Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I understand, I think this is outdated and not needed anymore? @jerryzh168 maybe can you confirm?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@howardzhang-cv oh I meant this is not needed for transformers (after their 5.0 update). but for diffusers I think we should preserve the behavior, until they do a similar update like transformers, wondering if this is planned? @sayakpaul

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but for diffusers I think we should preserve the behavior, until they do a similar update like transformers,

What is this similar update?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree with Jerry. Let's preserve the behaviour.

@sayakpaul
Copy link
Member

@howardzhang-cv thanks for the updates but I guess the following are remaining to be updated?

class TorchAoConfigMixin:

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the further updates.

Will wait for @jerryzh168 to provide an update on https://github.com/huggingface/diffusers/pull/13291/changes#r2977963964

| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |

Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
| **Category** | **Configuration Classes** |
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's practically nothing preventing the users from using the configs supported through TorchAO and they might not be limited to the ones we're including the in following table. For example, we can use the more recent NVFP4 and MXFP8 schemes (their respective config classes) here as well.

So, how about we provide examples to the popular config classes like Int8DynamicActivationInt4WeightConfig, Int8WeightOnlyConfig, and Float8DynamicActivationFloat8WeightConfig (with hyperlinks) and then provide a link to available config options (this will be a TorchAO doc link) for the users to explore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@howardzhang-cv
Copy link
Contributor Author

@howardzhang-cv thanks for the updates but I guess the following are remaining to be updated?

class TorchAoConfigMixin:

Nice catch, I updated it here to fit the new option

@howardzhang-cv
Copy link
Contributor Author

@sayakpaul I made some more minor changes:

  1. Fixed the documentation as per your comment
  2. Fixed some of the broken links in the example popular configs and removed some no longer supported ones
  3. Fixed some documentation the cogvideox markdown file
  4. Fixed the quantization.py file you were mentioning in the comment above.

@jerryzh168 when you get a chance, can you take a look at the src/diffusers/quantizers/torchao/torchao_quantizer.py comment above when you get a chance?

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Left some further minor comments. Major comment being the use of version=2 where possible.

pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="torchao",
quant_kwargs={"quant_type": "int8wo"},
quant_kwargs={"quant_type": Int8WeightOnlyConfig()},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we promote the use of version=2?


Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
| **Category** | **Configuration Classes** |
|---|---|
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this is cool!

logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree with Jerry. Let's preserve the behaviour.

Comment on lines -66 to -67
if is_torchao_version(">=", "0.9.0"):
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am guessing we're already fixing the minimum version to be 0.9.0?

@sayakpaul
Copy link
Member

I ran the tests and I am getting (pytest tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo):

FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_num_parameters[int4wo] - ImportError: Requires mslk >= 1.0.0
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_memory_footprint[int4wo] - ImportError: Requires mslk >= 1.0.0
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_inference[int4wo] - ImportError: Requires mslk >= 1.0.0
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_inference[int8dq] - RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize - NotImplementedError: QuantizationMethod.TORCHAO has no implementation of `dequantize`, please raise an issue on GitHub.

But these are failing on main as well. The mslk error seems unexpected. We didn't face it previously. I am on torchao 0.17.0.dev20260320+cu128 and torch 2.12.0.dev20260319+cu128.

I will look into the last two separately.

@sayakpaul
Copy link
Member

Some of the failing tests are also relevant to this PR, let's fix them. Example:
https://github.com/huggingface/diffusers/actions/runs/23516839193/job/68468358140?pr=13291

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants