Fix torchao group offloading with use_stream=True #14112
Open
Sunt-ing wants to merge 1 commit into
What does this PR do?
Fixes #13281
This is the TorchAO `use_stream=True` half split out from #14038, so that #14038 can stay focused on the quanto issue.

The streamed group-offload path keeps a CPU copy of each tensor and tries to pin that copy before transferring a group back to the accelerator. For TorchAO tensor subclasses, `_to_cpu()` already has to call `tensor.cpu()` instead of `tensor.data.cpu()`, but the stream path still calls `pin_memory()` and `is_pinned()` on the resulting subclass tensor. `AffineQuantizedTensor` does not implement those pinning ops, so `enable_group_offload(..., use_stream=True)` fails before the group can be onloaded.

This PR skips the pinning step for TorchAO tensors in the stream CPU cache. Plain tensors still use pinned memory, and the existing non-stream TorchAO swap path is unchanged.
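For reviewers who want the shape of the change without reading the diff, here is a minimal, torch-free sketch of the idea. It is not the diffusers or TorchAO code: `PlainTensor`, `QuantizedTensor`, `build_cpu_cache`, and the `skip_pinning_for` parameter are hypothetical stand-ins that only mimic the behavior described above (a subclass whose `pin_memory()` and `is_pinned()` are not implemented).

```python
class PlainTensor:
    """Hypothetical stand-in for a regular CPU tensor that supports pinning."""

    def __init__(self, name, pinned=False):
        self.name = name
        self._pinned = pinned

    def is_pinned(self):
        return self._pinned

    def pin_memory(self):
        # Return a pinned copy rather than mutating in place.
        return type(self)(self.name, pinned=True)


class QuantizedTensor(PlainTensor):
    """Hypothetical stand-in for a tensor subclass without pinning ops."""

    def is_pinned(self):
        raise NotImplementedError("is_pinned is not implemented")

    def pin_memory(self):
        raise NotImplementedError("pin_memory is not implemented")


def build_cpu_cache(tensors, skip_pinning_for=()):
    """Build the CPU-side cache, pinning only tensors that support it."""
    cache = []
    for t in tensors:
        if isinstance(t, skip_pinning_for):
            # Keep the unpinned CPU copy; never call the pinning ops on it.
            cache.append(t)
        else:
            cache.append(t if t.is_pinned() else t.pin_memory())
    return cache


tensors = [PlainTensor("norm.weight"), QuantizedTensor("linear.weight")]

# Without the guard, the first pinning call on the subclass raises.
try:
    build_cpu_cache(tensors)
    unguarded_error = None
except NotImplementedError as exc:
    unguarded_error = str(exc)

# With the guard, plain tensors are pinned and the subclass is passed through.
cache = build_cpu_cache(tensors, skip_pinning_for=(QuantizedTensor,))
```

The trade-off in this sketch matches the description above: tensors that skip pinning are still cached on the CPU, they just do not get a pinned copy, while plain tensors keep the pinned path.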
Tests
Environment: NVIDIA RTX 4090, `torch==2.8.0+cu128`, `torchao==0.17.0`.

Pipeline before/after repro
The before checkout is this branch with the patch reversed. The after checkout is this branch. The script uses the public tiny Flux pipeline, quantizes the transformer with TorchAO int8 weight-only quantization, enables `pipe.transformer.enable_group_offload(..., use_stream=True)`, moves the remaining modules to CUDA, and runs `pipe(...)`.

Before:
After:
Regression test:
Before submitting
.ai/review-rules.md?
Who can review?
cc @sayakpaul