remove str option for quantization config in torchao#13291
howardzhang-cv wants to merge 7 commits into huggingface:main
Conversation
Diff excerpt from `adjust_target_dtype` (context and changed lines shown together):

```python
if isinstance(quant_type, AOBaseConfig):
    # Extract size digit using fuzzy match on the class name
    config_name = quant_type.__class__.__name__
    size_digit = fuzzy_match_size(config_name)

    # Map the extracted digit to appropriate dtype
    if size_digit == "4":
        return CustomDtype.INT4
elif quant_type == "uintx_weight_only":
    return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
    return {
        1: torch.uint1,
        2: torch.uint2,
        3: torch.uint3,
        4: torch.uint4,
        5: torch.uint5,
        6: torch.uint6,
        7: torch.uint7,
    }[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
    return torch.bfloat16

elif is_torchao_version(">", "0.9.0"):
    from torchao.core.config import AOBaseConfig

    quant_type = self.quantization_config.quant_type
    if isinstance(quant_type, AOBaseConfig):
        # Extract size digit using fuzzy match on the class name
        config_name = quant_type.__class__.__name__
        size_digit = fuzzy_match_size(config_name)

        # Map the extracted digit to appropriate dtype
        if size_digit == "4":
            return CustomDtype.INT4
        else:
            # Default to int8
            return torch.int8
    else:
        # Default to int8
        return torch.int8
```
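The `fuzzy_match_size` helper used above is not shown in the hunk. As an illustration only, a minimal stand-in that extracts a bit-width digit from a config class name might look like this (the actual diffusers helper may behave differently):

```python
import re
from typing import Optional


def fuzzy_match_size(config_name: str) -> Optional[str]:
    """Extract a bit-width digit from a torchao config class name.

    E.g. "Int4WeightOnlyConfig" -> "4". When several widths appear
    (e.g. "Int8DynamicActivationInt4WeightConfig"), return the smallest,
    since that determines the packed storage dtype. Returns None when
    no int/uint/float/fp width is found in the name.
    """
    matches = re.findall(r"(?:int|uint|float|fp)(\d)", config_name.lower())
    return min(matches) if matches else None
```

With this sketch, `fuzzy_match_size("Int4WeightOnlyConfig")` yields `"4"`, which the quantizer maps to `CustomDtype.INT4`, and any other width falls through to the int8 default.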
This seems a bit fragile. I think it's from transformers originally; I'm not sure it's still needed in transformers, though — it might have been refactored after the 5.0 update.
Sorry, might be a dumb question: which part are you referring to as originally from transformers? Is it the entire `adjust_target_dtype` function?
```python
    f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
    f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
if is_torchao_version("<=", "0.9.0"):
```
Separate PR: I feel we should just have a single assertion that torchao is a relatively recent version (e.g., 0.15) and remove all these version checks.
Actually, yeah, I was going to ask you about that as well. There are a couple of version checks scattered around right now. It would be cleaner to just remove all of them.
Yeah we should mandate a minimum version requirement here.
Do we want to do this in a separate PR? I changed it and set it to 0.9.0 because that's when AOBaseConfig was first supported. Moving to 0.15.0 might be cleaner in a separate PR, in case we need to revert for whatever reason.
Force-pushed from ed295a7 to d49eb2e.
sayakpaul left a comment:
Thanks a lot for starting this work!
I think we can merge this fairly soon.
```python
if not isinstance(self.quant_type, AOBaseConfig):
    raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
```
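The new guard means string shorthands like `"int8wo"` now fail fast instead of being silently accepted. A self-contained sketch of the behavior, using stand-in classes rather than the real diffusers/torchao types:

```python
class AOBaseConfig:
    """Stand-in for torchao.core.config.AOBaseConfig."""


class Int8WeightOnlyConfig(AOBaseConfig):
    """Stand-in for torchao's Int8WeightOnlyConfig."""


def validate_quant_type(quant_type):
    # Mirrors the check in the hunk above: only AOBaseConfig instances
    # are accepted; string shorthands raise a TypeError.
    if not isinstance(quant_type, AOBaseConfig):
        raise TypeError(
            f"quant_type must be an AOBaseConfig instance, got {type(quant_type).__name__}"
        )
    return quant_type
```

Passing `Int8WeightOnlyConfig()` succeeds, while passing the old `"int8wo"` string raises `TypeError: quant_type must be an AOBaseConfig instance, got str`.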
@bot /style

Style bot fixed some files and pushed the changes.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
logger.warning(
    f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
    f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
```
Should we not implement an equivalent of:

```python
if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
    if torch_dtype is not None and torch_dtype != torch.bfloat16:
        logger.warning(
            f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
            f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
        )
```
?
From what I understand, this is outdated and no longer needed? @jerryzh168, can you confirm?
@howardzhang-cv oh, I meant this is not needed for transformers (after their 5.0 update), but for diffusers I think we should preserve the behavior until they do a similar update. Wondering if this is planned? @sayakpaul
> but for diffusers I think we should preserve the behavior, until they do a similar update like transformers

What is this similar update?
Yes, I agree with Jerry. Let's preserve the behaviour.
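Preserving the old behaviour means the quantizer still needs to detect int/uint quantization to emit the bfloat16 warning, but with `AOBaseConfig` instances there is no string prefix to test. One hypothetical option (not the actual diffusers implementation) is matching on the config class name, sketched here with stand-in classes:

```python
class AOBaseConfig:
    """Stand-in for torchao.core.config.AOBaseConfig."""


class Int8WeightOnlyConfig(AOBaseConfig):
    """Stand-in for an integer weight-only config."""


class Float8WeightOnlyConfig(AOBaseConfig):
    """Stand-in for a float8 config."""


def is_int_like_config(config: AOBaseConfig) -> bool:
    # The class name carries the same information the old "int"/"uint"
    # string-prefix check relied on.
    name = type(config).__name__.lower()
    return name.startswith("int") or name.startswith("uint")
```

The bfloat16 warning from the quoted snippet could then be gated on `is_int_like_config(quant_type)` instead of `quant_type.startswith(...)`.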
@howardzhang-cv thanks for the updates, but I guess the following still remain to be updated?
sayakpaul left a comment:
Thanks for the further updates.
Will wait for @jerryzh168 to provide an update on https://github.com/huggingface/diffusers/pull/13291/changes#r2977963964
```
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |

Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.

| **Category** | **Configuration Classes** |
```
There's practically nothing preventing users from using any config supported by TorchAO, and those aren't limited to the ones we're including in the following table. For example, the more recent NVFP4 and MXFP8 schemes (their respective config classes) can be used here as well.

So, how about we provide examples of the popular config classes like `Int8DynamicActivationInt4WeightConfig`, `Int8WeightOnlyConfig`, and `Float8DynamicActivationFloat8WeightConfig` (with hyperlinks), and then provide a link to the available config options (a TorchAO doc link) for users to explore?
https://docs.pytorch.org/ao/main/workflows/inference.html#inference-workflows will be useful, I think.
Nice catch, I updated it here to fit the new option.
@sayakpaul I made some more minor changes.

@jerryzh168 when you get a chance, can you take a look at the `src/diffusers/quantizers/torchao/torchao_quantizer.py` comment above?
sayakpaul left a comment:
Thanks! Left some further minor comments. The major one is the use of `version=2` where possible.
```diff
 pipeline_quant_config = PipelineQuantizationConfig(
     quant_backend="torchao",
-    quant_kwargs={"quant_type": "int8wo"},
+    quant_kwargs={"quant_type": Int8WeightOnlyConfig()},
```
Perhaps we promote the use of version=2?
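Promoting `version=2` means constructing the config with the version pinned explicitly rather than relying on the default. A minimal sketch with a hypothetical stand-in dataclass (the real `Int8WeightOnlyConfig` lives in torchao and has more fields; its default version may differ by release):

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-in for torchao's Int8WeightOnlyConfig, only to
# illustrate pinning `version=2` explicitly in the docs examples.
@dataclass
class Int8WeightOnlyConfig:
    group_size: Optional[int] = None
    version: int = 1  # assumed default for illustration


# Docs examples would pass version=2 explicitly:
cfg = Int8WeightOnlyConfig(version=2)
```

Pinning the version in documentation examples keeps them stable even if torchao changes the default in a later release.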
```
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.

| **Category** | **Configuration Classes** |
|---|---|
```
```python
if is_torchao_version(">=", "0.9.0"):
    pass
```

I am guessing we're already fixing the minimum version to be 0.9.0?
I ran the tests and I am getting:

```
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_num_parameters[int4wo] - ImportError: Requires mslk >= 1.0.0
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_memory_footprint[int4wo] - ImportError: Requires mslk >= 1.0.0
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_inference[int4wo] - ImportError: Requires mslk >= 1.0.0
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_inference[int8dq] - RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize - NotImplementedError: QuantizationMethod.TORCHAO has no implementation of `dequantize`, please raise an issue on GitHub.
```

But these are failing on

I will look into the last two separately.
Some of the failing tests are also relevant to this PR; let's fix them. Example:
What does this PR do?

Remove the deprecated string-based `quant_type` path from `TorchAoConfig`, requiring `AOBaseConfig` instances instead.

Testing

```shell
python -m pytest tests/quantization/torchao/test_torchao.py -xvs
```

Who can review?

@sayakpaul