Skip to content

Fix group offloading for quanto-quantized models#14038

Open
Sunt-ing wants to merge 2 commits into
huggingface:mainfrom
Sunt-ing:0
Open

Fix group offloading for quanto-quantized models#14038
Sunt-ing wants to merge 2 commits into
huggingface:mainfrom
Sunt-ing:0

Conversation

@Sunt-ing

@Sunt-ing Sunt-ing commented Jun 22, 2026

Copy link
Copy Markdown

What does this PR do?

Fixes #12610

Group offloading moves a group's parameters between CPU and the accelerator by reassigning param.data:

param.data = source_tensor.to(device)

This is correct for plain tensors but wrong for quanto tensor subclasses. A quanto WeightQBytesTensor stores the real payload in internal tensors such as _data and _scale; replacing .data only swaps the outer wrapper and leaves those internal tensors on the source device. The next matmul then fails with mat2 is on cpu, different from cuda:0.

#13276 fixed the same subclass-storage issue for TorchAO tensors by swapping the full tensor subclass instead of assigning .data, but quanto tensors still fall through to the plain tensor path. This PR adds the corresponding quanto path and keeps the TorchAO stream fix split out in #14112.

Changes

  • Detect quanto QTensor parameters without importing optimum-quanto unless it is installed.
  • Use torch.utils.swap_tensors for quanto onload/offload instead of assigning .data.
  • Restore and record streams for quanto internal tensors using the subclass tensor names from __tensor_flatten__().
  • Skip pinned-memory conversion for quanto tensors, since pin_memory() does not preserve the quanto subclass.

Tests

Environment: NVIDIA RTX 4090, torch==2.8.0+cu128, optimum-quanto==0.2.7.

Reproduction and before/after

Minimal standalone repro for #12610:

import torch
from diffusers import UNet2DConditionModel
from diffusers.hooks import apply_group_offloading
from optimum.quanto import quantize, freeze, qint8

m = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe", subfolder="unet"
).to(torch.float32).eval()
quantize(m, weights=qint8)
freeze(m)
apply_group_offloading(
    m,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)
x = torch.randn(2, m.config.in_channels, m.config.sample_size, m.config.sample_size, device="cuda")
t = torch.tensor([10, 10], device="cuda")
e = torch.randn(2, 4, m.config.cross_attention_dim, device="cuda")
with torch.no_grad():
    m(x, t, e)

On main, this fails with:

RuntimeError: mat2 is on cpu, different from cuda:0

With this PR, quanto group offload matches the fully-on-accelerator quantized baseline across leaf_level, block_level, non-stream, use_stream, and record_stream configs. The maximum absolute difference is 0.0.

Regression tests:

python -m pytest tests/quantization/quanto/test_quanto.py::FluxTransformerInt8WeightsTest::test_group_offloading -q
python -m pytest tests/quantization/quanto/test_quanto.py::FluxTransformerFloat8WeightsTest::test_group_offloading -q

Both tests fail on main with the device mismatch and pass with this PR.

Before submitting

Who can review?

cc @sayakpaul

@github-actions github-actions Bot added fixes-issue size/M PR with diff < 200 LOC tests hooks and removed size/M PR with diff < 200 LOC labels Jun 22, 2026
@sayakpaul

Copy link
Copy Markdown
Member

Group offloading should have been fixed, though with #13276. Can you check again?

@Sunt-ing

Copy link
Copy Markdown
Author

Hi @sayakpaul, thanks. Yes, I rechecked against #13276 before opening this. #13276 makes group offloading work for torchao by swapping the subclass (_is_torchao_tensortorch.utils.swap_tensors on onload, setattr of inner tensors on the offload restore). Two cases it doesn't cover are exactly what this PR targets:

Both #12610 and #13281 are still open. I confirmed on current main (so with #13276 in) that the three tests this PR adds fail, and pass here:

main (with #13276) vs this PR
# main (fix reverted, tests kept)
quanto  FluxTransformerInt8WeightsTest::test_group_offloading    FAILED  (mat2 is on cpu, different from cuda:0)
quanto  FluxTransformerFloat8WeightsTest::test_group_offloading  FAILED  (mat2 is on cpu, different from cuda:0)
torchao TorchAoTest::test_group_offloading                       FAILED  (NotImplementedError: ... aten.is_pinned)

# with this PR
quanto  FluxTransformerInt8WeightsTest::test_group_offloading    PASSED
quanto  FluxTransformerFloat8WeightsTest::test_group_offloading  PASSED
torchao TorchAoTest::test_group_offloading                       PASSED

On approach: I deliberately mirrored the existing _is_torchao_tensor branch rather than touching it, to keep this a low-risk bug fix (_is_quanto_tensor gates on is_optimum_quanto_available() and pulls inner-tensor names from the standard __tensor_flatten__()). I also saw your note in #13276 about generalizing these utilities to swap_tensors for any subclass instead of .data. Happy to fold torchao + quanto into one generic subclass path here, or leave that as the separate follow-up you mentioned, whichever you prefer.

@sayakpaul

Copy link
Copy Markdown
Member

Can we focus on one issue at a time? Therefore, I would suggest splitting the PR into two.

@github-actions github-actions Bot added the size/M PR with diff < 200 LOC label Jul 2, 2026
@Sunt-ing Sunt-ing changed the title Fix group offloading for quanto-quantized models and the use_stream path for quantized tensor subclasses Fix group offloading for quanto-quantized models Jul 2, 2026
@Sunt-ing

Sunt-ing commented Jul 2, 2026

Copy link
Copy Markdown
Author

Thanks @sayakpaul. I split the TorchAO use_stream=True fix into #14112 and updated this PR to focus on the quanto issue in #12610.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Quanto + Group Offload causes device mismatch error (weights on cpu, mat1 on gpu)

2 participants