Invalidate HookRegistry child-registries cache on enable/disable cache #14093
SuryanshSS1011 wants to merge 2 commits into
assert registry.get_hook("stateful_add_hook").increment == 1
assert torch.allclose(output1, output2)

def test_child_registries_cache_invalidation(self):
Should there also be testing when the context path is exercised, as reported in #14037?
Yes, that is a good call!
I've added test_cache_context_after_enable_cache_with_prior_context, which reproduces the issue end-to-end: it enters cache_context() before enable_cache() on a small FluxTransformer2DModel and runs a cached forward. It raises the original ValueError: No context is set without the fix and passes with it. Also fixed the check_code_quality failure (a docstring line-length restyle).
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
42f7a74 to 424f752
🤗 Serge says:
Clean, well-targeted fix for #14037. The stale-cache diagnosis is accurate: _get_child_registries() memoizes the child-registry walk and enable_cache/disable_cache change which modules carry a _diffusers_hook, so invalidating after hooks are added/removed is the right correction.
Correctness
- invalidate_child_registries_cache() walks the full tree and resets _child_registries_cache = None, which _get_child_registries() correctly treats as "rebuild on next use". Clearing every registry in the subtree (not just the root) is right, since a child registry can also appear in an ancestor's cache.
- Invalidation is placed in enable_cache/disable_cache on the root module rather than at the true source (register_hook/remove_hook), which the author calls out. Given a child registry can't reach a stale ancestor cache, this is a reasonable and minimal choice that covers the reported path for every cache technique.
- The HookRegistry import added to enable_cache/disable_cache and the top-level FirstBlockCacheConfig/FluxTransformer2DModel imports in the test are all valid exports.
Tests
- Both the unit-level (test_child_registries_cache_invalidation) and end-to-end (test_cache_context_after_enable_cache_with_prior_context) tests exercise the fix and match the failure described in the issue. They use small CPU-friendly configs consistent with the rest of the file.
Matches the PR description. No blocking issues.
serge v0.1.0 · model: claude-opus-4-8 · 8 LLM turns · 9 tool calls · 36.9s · 145457 in / 2138 out tokens
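To illustrate the point about clearing every registry in the subtree, here is a minimal sketch in plain Python. ToyRegistry, get_children_with_hooks, invalidate_tree and demo are hypothetical stand-ins invented for this example; they only mimic the memoize-then-invalidate pattern described above and are not the diffusers HookRegistry code.

```python
# Hypothetical toy classes: NOT the diffusers implementation.
class ToyRegistry:
    """A registry node that memoizes which descendants carry a hook."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.has_hook = False
        self._child_cache = None  # None means "rebuild on next use"

    def walk(self):
        # Yield this registry and every registry below it.
        yield self
        for child in self.children:
            yield from child.walk()

    def get_children_with_hooks(self):
        # Memoized walk: the result is only as fresh as the last rebuild.
        if self._child_cache is None:
            self._child_cache = [
                r.name for r in self.walk() if r is not self and r.has_hook
            ]
        return self._child_cache

    def invalidate_tree(self):
        # Reset the cache on every registry in the subtree, not just this one.
        for registry in self.walk():
            registry._child_cache = None


def demo():
    leaf = ToyRegistry("block.attn")
    mid = ToyRegistry("block", [leaf])
    root = ToyRegistry("root", [mid])

    # Populate both caches while no hooks exist.
    root.get_children_with_hooks()
    mid.get_children_with_hooks()

    leaf.has_hook = True          # a hook is added afterwards
    root._child_cache = None      # clearing only the root...
    root_only = (root.get_children_with_hooks(), mid.get_children_with_hooks())

    root.invalidate_tree()        # ...versus clearing the whole subtree
    full = (root.get_children_with_hooks(), mid.get_children_with_hooks())
    return root_only, full
```

In this toy setup, resetting only the root leaves the intermediate registry with its old, empty list, while the tree-wide reset refreshes both.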
# Warmup pass inside a cache_context() while caching is disabled, then enable caching.
with torch.no_grad(), model.cache_context("cond"):
    model(**inputs)
model.enable_cache(FirstBlockCacheConfig(threshold=0.1))
Help me understand why this would be a practical flow, though.
with torch.no_grad(), model.cache_context("cond"): isn't doing anything from the perspective of caching. So, does this reflect how caching is actually used in practice?
So our flow here is enabling caching on a transformer that has already been called once inside cache_context().
The pipelines call cache_context() unconditionally around every transformer forward. For example pipeline_qwenimage.py does with self.transformer.cache_context("cond"): on every denoise step regardless of whether caching is enabled. So the first cache_context() in my test mirrors a normal pipeline run. It's a no-op for caching but it still builds HookRegistry._child_registries_cache. If the user then calls enable_cache(...) and runs again, which is a natural "run once, then turn on caching to speed up the next run" workflow, the second cache_context() hits that stale cache and the block hooks never get a context, which raises ValueError: No context is set.
To confirm it's not contrived, calling enable_cache() before any run works fine. The bug only appears when a cache_context() pass precedes enable_cache(). The test reduces that to the minimal repro, using direct model(**inputs) calls in place of a full pipeline loop so it stays CPU-only and fast, but the ordering is the same one a real pipeline produces.
We can reshape the test to go through an actual pipeline call if you'd prefer it to read more explicitly as the real flow.
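For readers who want to see the ordering problem without loading a model, the following is a rough, dependency-free sketch of it. ToyBlock, ToyModel and run_then_enable are hypothetical names made up for this illustration, and the invalidate flag is an assumption used to toggle the fix on and off; none of this is the actual diffusers code path.

```python
# Hypothetical toy model: NOT the diffusers implementation.
from contextlib import contextmanager


class ToyBlock:
    """A child module; `hooked` mimics carrying a block-level cache hook."""

    def __init__(self):
        self.hooked = False
        self.context = None

    def forward(self):
        # A hooked block needs a context, like the block state managers do.
        if self.hooked and self.context is None:
            raise ValueError("No context is set.")


class ToyModel:
    def __init__(self, num_blocks=2):
        self.blocks = [ToyBlock() for _ in range(num_blocks)]
        self._hooked_blocks_cache = None  # memoized list of hooked blocks

    def _hooked_blocks(self):
        if self._hooked_blocks_cache is None:
            self._hooked_blocks_cache = [b for b in self.blocks if b.hooked]
        return self._hooked_blocks_cache

    @contextmanager
    def cache_context(self, name):
        # Set the context only on blocks found through the memoized list.
        for block in self._hooked_blocks():
            block.context = name
        try:
            yield
        finally:
            for block in self._hooked_blocks():
                block.context = None

    def enable_cache(self, invalidate):
        for block in self.blocks:
            block.hooked = True
        if invalidate:  # the fix: drop the memoized list after adding hooks
            self._hooked_blocks_cache = None

    def forward(self):
        for block in self.blocks:
            block.forward()


def run_then_enable(invalidate):
    model = ToyModel()
    with model.cache_context("cond"):  # warmup pass, caching disabled
        model.forward()
    model.enable_cache(invalidate=invalidate)
    try:
        with model.cache_context("cond"):  # second pass, caching enabled
            model.forward()
    except ValueError as err:
        return f"failed: {err}"
    return "ok"
```

Calling run_then_enable(False) reproduces the "run once, then turn on caching" failure in miniature, and run_then_enable(True) shows the same sequence passing once the memoized list is dropped.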
What does this PR do?
Fixes #14037.
HookRegistry._get_child_registries() caches the child-module registries it finds by walking named_modules(), and never invalidates that cache. But enable_cache()/disable_cache() add and remove block-level hooks, changing which modules carry a _diffusers_hook. If cache_context() is first entered while no block hooks exist (e.g. a warmup pass with caching disabled), the parent registry caches an incomplete child list. A later enable_cache(FirstBlockCacheConfig(...)) registers block hooks, but _set_context() still iterates the stale cache, so the new block StateManagers never receive a context and the next cached forward raises:
ValueError: No context is set. Please set a context before retrieving the state.
This adds HookRegistry.invalidate_child_registries_cache(), which clears the cached list across the module tree, and calls it from enable_cache() and disable_cache() after hooks are added/removed.
The staleness originates in register_hook/remove_hook, but those run on the child block registries, which can't reach the parent registry whose cache is stale. enable_cache/disable_cache operate on the root module, so invalidating there covers the reported scenario for every cache technique. Happy to move it into register_hook/remove_hook instead if you'd prefer it lower down.
The self-contained CPU reproduction from the issue passes after the fix, and a regression test is added in tests/hooks/test_hooks.py.
Before submitting
.ai/review-rules.md?
Who can review?
@DN6 @sayakpaul