-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Invalidate HookRegistry child-registries cache on enable/disable cache #14093
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import pytest | ||
| import torch | ||
|
|
||
| from diffusers import FirstBlockCacheConfig, FluxTransformer2DModel | ||
| from diffusers.hooks import HookRegistry, ModelHook | ||
| from diffusers.training_utils import free_memory | ||
| from diffusers.utils.logging import get_logger | ||
|
|
@@ -200,6 +201,26 @@ def test_stateful_hook(self): | |
| assert registry.get_hook("stateful_add_hook").increment == 1 | ||
| assert torch.allclose(output1, output2) | ||
|
|
||
| def test_child_registries_cache_invalidation(self): | ||
| # Unit-level part of the regression test for | ||
| # https://github.com/huggingface/diffusers/issues/14037: the parent registry caches its | ||
| # child registries, so a hook registered on a child block after the cache was built is | ||
| # invisible to the parent until the cache is invalidated. | ||
| parent = HookRegistry.check_if_exists_or_initialize(self.model) | ||
|
|
||
| # Build the parent's child-registry cache while no block carries a hook yet. | ||
| assert parent._get_child_registries() == [] | ||
|
|
||
| # Register a hook on a child block. The parent's cached (empty) list is now stale. | ||
| block = self.model.blocks[0] | ||
| child = HookRegistry.check_if_exists_or_initialize(block) | ||
| child.register_hook(AddHook(1), "add_hook") | ||
| assert parent._get_child_registries() == [] # still stale before invalidation | ||
|
|
||
| # Invalidating across the tree makes the new child registry reachable from the parent. | ||
| parent.invalidate_child_registries_cache() | ||
| assert child in parent._get_child_registries() | ||
|
|
||
| def test_inference(self): | ||
| registry = HookRegistry.check_if_exists_or_initialize(self.model) | ||
| registry.register_hook(AddHook(1), "add_hook") | ||
|
|
@@ -372,3 +393,46 @@ def test_invocation_order_stateful_last(self): | |
| .replace("\n", "") | ||
| ) | ||
| assert output == expected_invocation_order_log | ||
|
|
||
|
|
||
| def test_cache_context_after_enable_cache_with_prior_context(): | ||
| # End-to-end regression test for https://github.com/huggingface/diffusers/issues/14037. Entering | ||
| # cache_context() before enable_cache() builds the model's child-registry cache without the block | ||
| # hooks. enable_cache() then registers FirstBlockCache hooks on the blocks; if the cache is not | ||
| # invalidated, _set_context() iterates the stale list and the new block hooks never receive a | ||
| # context, so the next cached forward raises "No context is set". | ||
| torch.manual_seed(0) | ||
| heads, head_dim = 2, 16 | ||
| hidden = heads * head_dim | ||
| model = FluxTransformer2DModel( | ||
| patch_size=1, | ||
| in_channels=hidden, | ||
| num_layers=2, | ||
| num_single_layers=2, | ||
| attention_head_dim=head_dim, | ||
| num_attention_heads=heads, | ||
| joint_attention_dim=32, | ||
| pooled_projection_dim=16, | ||
| guidance_embeds=False, | ||
| axes_dims_rope=(2, 6, 8), | ||
| ).eval() | ||
|
|
||
| img_seq_len, txt_seq_len = 8, 4 | ||
| inputs = { | ||
| "hidden_states": torch.randn(1, img_seq_len, hidden), | ||
| "encoder_hidden_states": torch.randn(1, txt_seq_len, 32), | ||
| "pooled_projections": torch.randn(1, 16), | ||
| "timestep": torch.tensor([1.0]), | ||
| "img_ids": torch.zeros(img_seq_len, 3), | ||
| "txt_ids": torch.zeros(txt_seq_len, 3), | ||
| "return_dict": False, | ||
| } | ||
|
|
||
| # Warmup pass inside a cache_context() while caching is disabled, then enable caching. | ||
| with torch.no_grad(), model.cache_context("cond"): | ||
| model(**inputs) | ||
| model.enable_cache(FirstBlockCacheConfig(threshold=0.1)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Help me understand why would this be a practical flow though? Like
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So our flow here is enabling caching on a transformer that has already been called once inside The pipelines call To confirm it's not contrived, calling We can reshape the test to go through an actual pipeline call if you'd prefer it to read more explicitly as the real flow.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah let's do that |
||
|
|
||
| # Previously raised "No context is set"; the cache invalidation in enable_cache() fixes it. | ||
| with torch.no_grad(), model.cache_context("cond"): | ||
| model(**inputs) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should there also be testing when the context path is exercised, as reported in #14037?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that is a good call!
I've added
test_cache_context_after_enable_cache_with_prior_context, which reproduces the issue end-to-end: it enterscache_context()beforeenable_cache()on a small FluxTransformer2DModel and runs a cached forward. It raises the originalValueError: No context is set without the fix and passes with it. Also fixed thecheck_code_qualityfailure (a docstring line-length restyle).