feat: inject TransformerEngine DotProductAttention into HF models#2011
HuiyingLi merged 30 commits into NVIDIA-NeMo:main
Conversation
Adds `nemo_automodel/_transformers/te_attention.py`, which post-init patches each `self_attn` module in a HuggingFace model to route the core attention kernel through TE's `DotProductAttention` (FlashAttention-3 / FP8-capable), without requiring model-specific rewrites. Usage: pass `attn_implementation="te"` to `from_pretrained` / `from_config`. The framework converts this to `"sdpa"` for HF model init, then calls `inject_te_attention()` after PEFT application and before FSDP sharding.

Key design decisions:

- Intercepts `F.scaled_dot_product_attention` inside each attention forward via a per-module forward wrapper (restored via try/finally).
- Stores `attn_module = DotProductAttention(...)` on each `self_attn` so `_uses_te_attention()` detection is O(1) via the `_te_attention_injected` flag.
- Handles GQA by undoing `repeat_kv` before handing off to TE's native `num_gqa_groups` path.
- Falls back to native SDPA for non-trivial `attn_mask` tensors (v1 scope).
- Skips injection when packed sequences force `flash_attention_2`.

Also updates `capabilities.py:_uses_te_attention` to detect injected TE (in addition to the existing `BackendConfig.attn='te'` check for custom models), and adds an `inject_te_attention: bool` parameter to `apply_model_infrastructure`. 20 unit tests cover param inference, SDPA replacement logic, injection lifecycle, and forward-patch restoration on exception.

Signed-off-by: khazic <khazzz1c@gmail.com>
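As a rough usage sketch (the import path and class name below are assumptions for illustration; the PR only specifies passing `attn_implementation="te"` to `from_pretrained` / `from_config`):

```python
# Hedged sketch: import path/class name assumed, not verified against the repo.
from nemo_automodel import NeMoAutoModelForCausalLM

model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    attn_implementation="te",  # loaded as "sdpa" internally, TE injected post-init
)
```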
…sliding window

- `_infer_attn_params` now falls back to `q_proj.out_features // head_dim` when the module has no `num_heads` attribute (fixes Gemma4TextAttention).
- Reads the `sliding_window` attribute and converts it to TE `window_size=(W-1, 0)`.
- Threads `window_size` through `_create_te_dot_product_attention` and `_make_te_sdpa`.
- Updates tests: `standard_layout` now checks the `window_size` key; replaces `test_missing_num_heads_returns_none` with an inference-from-proj test; adds sliding-window coverage.

Signed-off-by: khazic <khazzz1c@gmail.com>
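A minimal sketch of that sliding-window conversion (helper name illustrative, not the actual code; the `(W-1, 0)` convention is per the commit message):

```python
def _window_size_from_module(attn_module):
    # Sketch: map HF's sliding_window attribute to TE's window_size=(W-1, 0)
    # convention, i.e. W-1 tokens of left context and no right context (causal).
    sliding_window = getattr(attn_module, "sliding_window", None)
    return (sliding_window - 1, 0) if sliding_window is not None else None
```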
…thout out_features

`Gemma4ClippableLinear` does not expose `out_features`, so inferring `num_heads` from `q_proj.out_features` failed silently. `weight.shape[0]` gives the same value and works on meta device.

Signed-off-by: khazic <khazzz1c@gmail.com>
…d linears

`Gemma4ClippableLinear` nests the actual `nn.Linear` under a `.linear` child module. `_proj_out_features` now tries `out_features` → `weight.shape[0]` → `.linear` recursion, so all three layouts resolve correctly.

Signed-off-by: khazic <khazzz1c@gmail.com>
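A sketch of that three-step fallback (this restates the commit's logic; it is not the actual helper):

```python
def _proj_out_features_sketch(proj):
    if proj is None:
        return None
    # 1) plain nn.Linear exposes out_features directly
    if hasattr(proj, "out_features"):
        return proj.out_features
    # 2) custom modules may only expose a weight; weight.shape[0] equals
    #    out_features and also works on meta device
    weight = getattr(proj, "weight", None)
    if weight is not None:
        return weight.shape[0]
    # 3) wrappers that nest the real nn.Linear under a .linear child module
    inner = getattr(proj, "linear", None)
    return _proj_out_features_sketch(inner) if inner is not None else None
```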
Signed-off-by: khazic <khazzz1c@gmail.com>
Force-pushed from 3e01aac to 4020024
/ok to test 4020024
Update: also tested on Gemma4-31B-it (single GPU, bf16): all self_attn modules injected, forward pass OK with the expected logits shape.
…rdcoding head_dim**-0.5

Models like Gemma4 store a pre-computed scale in `module.scaling` that differs from the standard `1/sqrt(head_dim)`. Fall back to `head_dim**-0.5` only when the attribute is absent.

Signed-off-by: khazic <khazzz1c@gmail.com>
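In code, that scale selection amounts to something like this (sketch; attribute name per the commit message):

```python
def _softmax_scale_sketch(attn_module, head_dim):
    # Prefer a model-provided precomputed scale (e.g. Gemma4's module.scaling),
    # falling back to the standard 1/sqrt(head_dim) only when it is absent.
    scale = getattr(attn_module, "scaling", None)
    return scale if scale is not None else head_dim ** -0.5
```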
/ok to test b024cc7
Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: khazic <khazzz1c@gmail.com>
- Downgrade the per-call dispatch log from INFO to DEBUG (it was firing once per self_attn per forward and noticeably slowed multi-layer VLM runs).
- Add module-level hit/fallback counters (`te_hits`, `fallback_mask`, `fallback_scale_mismatch`) plus `get_te_attention_stats()` / `reset_te_attention_stats()` helpers, and an optional periodic summary gated by `AUTOMODEL_TE_STATS_EVERY` (default 500, 0 disables).
- `te_sdpa` no longer silently ignores the runtime `scale` argument: if a caller passes an explicit `scale` that disagrees with the `softmax_scale` baked into the TE module at construction, fall back to native SDPA and emit a one-shot warning. Previously TE would quietly use the baked-in value, which could produce wrong numerics on models that pass `scale=` per-call.
- Short-circuit the mask fallback before the transpose/contiguous copies so the fallback path no longer pays QKV copy overhead.
- Add `tools/debug_te_attention.py`: a single-GPU smoke test that runs a forward pass through a small HF model with TE injected and prints the dispatch counters, so the fallback ratio can be validated on A100 without needing FA3 hardware.

Signed-off-by: khazic <khazzz1c@gmail.com>
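A hedged usage sketch of the counters API (module path per the PR description; exact signatures assumed, counter keys per the commit message):

```python
from nemo_automodel._transformers.te_attention import (
    get_te_attention_stats,
    reset_te_attention_stats,
)

reset_te_attention_stats()
# ... run a few forward passes through a TE-injected model ...
stats = get_te_attention_stats()
print(f"te_hits={stats['te_hits']} "
      f"fallback_mask={stats['fallback_mask']} "
      f"fallback_scale_mismatch={stats['fallback_scale_mismatch']}")
```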
…vocab_size

`AutoModelForCausalLM` fails on Gemma4 (VLM); switch to `AutoModel`. Also fall back to `text_config.vocab_size` so the script works for both pure-LM and VLM model configs without a hardcoded class name.

Signed-off-by: khazic <khazzz1c@gmail.com>
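The config fallback boils down to roughly this (sketch):

```python
def _vocab_size_sketch(config):
    # Pure-LM configs expose vocab_size at the top level; VLM configs (e.g. Gemma4)
    # nest the language-model config, so fall back to text_config.vocab_size.
    vocab_size = getattr(config, "vocab_size", None)
    if vocab_size is None:
        vocab_size = config.text_config.vocab_size
    return vocab_size
```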
Signed-off-by: khazic <khazzz1c@gmail.com>
…rnel

HF models (Gemma4, Llama, etc.) always produce a 4D float attention mask via `create_causal_mask` / `create_sliding_window_causal_mask`, even when the batch has no padding. This caused 83% of self_attn calls to fall back to native SDPA, completely negating any FA3/TE speedup.

Add `_detect_causal_mask()`, which identifies the two structurally trivial cases with O(S) tensor ops (two row reductions):

- Pure causal lower-triangular mask → `attn_mask_type='causal'`, drop mask.
- Sliding-window causal mask → `attn_mask_type='causal'`, `window_size=(W-1, 0)`.

All batch items are checked (not just item 0) so padded batches still fall back safely to native SDPA. This converts the majority of `fallback_mask` calls to `te_hits` for standard no-padding training runs.

Signed-off-by: khazic <khazzz1c@gmail.com>
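An illustrative version of the pure-causal check (a sketch in the spirit of `_detect_causal_mask`, shown for a boolean `[B, 1, S, S]` mask only; the real helper also handles float masks and the sliding-window case):

```python
import torch

def _looks_pure_causal(mask: torch.Tensor) -> bool:
    # Row reduction: in a pure causal lower-triangular mask, row i of every batch
    # item sees exactly i + 1 positions, so one O(S) reduction checks all rows.
    bsz, _, s_q, s_k = mask.shape
    if s_q != s_k:
        return False
    visible_per_row = mask.sum(dim=-1)                       # [B, 1, S]
    expected = torch.arange(1, s_q + 1, device=mask.device)  # 1, 2, ..., S
    return bool((visible_per_row == expected).all())         # every batch item must match
```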
Log dtype, shape, corner values, and per-row visible counts at DEBUG level so we can see exactly why mask detection returns None. Signed-off-by: khazic <khazzz1c@gmail.com>
…formers

Newer transformers versions (e.g. 4.48+) emit bool attention masks (True = attend, False = masked) instead of float additive masks (-inf/0). `_detect_causal_mask` previously rejected all bool masks, causing 100% fallback to native SDPA on these transformers versions.

Support both formats:

- bool mask: True = visible; count with `.sum()`; corner check with `not corner.item()`.
- float mask: `> -1e4` = visible (existing logic, unchanged).

Signed-off-by: khazic <khazzz1c@gmail.com>
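The dual-format visibility test described above is roughly (sketch):

```python
import torch

def _visible_positions(mask: torch.Tensor) -> torch.Tensor:
    # bool masks from newer transformers: True already means "attend".
    if mask.dtype == torch.bool:
        return mask
    # float additive masks: ~0 means attend, large negative means masked.
    return mask > -1e4
```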
…loading Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: khazic <khazzz1c@gmail.com>
A100 dispatch validation results

Validated TE kernel dispatch on A100 (correctness only; FA3 speedup requires H100, pending).

Root causes fixed in this update
Dispatch counter results
/ok to test 6034494
…te test for softmax_scale

`torch.cuda.device(cpu_device)` raises ValueError in newer PyTorch because `_get_device_index` no longer uses `optional=True`. Use `contextlib.nullcontext()` when the input tensor is not on CUDA.

`_infer_attn_params` now returns `softmax_scale`; update `test_standard_layout` to include it in the expected dict.

Signed-off-by: khazic <khazzz1c@gmail.com>
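The fix follows this pattern (helper name illustrative):

```python
import contextlib
import torch

def _maybe_cuda_device(t: torch.Tensor):
    # torch.cuda.device(cpu_device) raises ValueError on newer PyTorch, so only
    # enter a CUDA device context when the tensor actually lives on a GPU.
    return torch.cuda.device(t.device) if t.is_cuda else contextlib.nullcontext()
```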
/ok to test 7eabac1
we might want to remove this
Drop the standalone injection smoke-test script; the dispatch counters are now exposed via ``get_te_attention_stats()`` for callers that need them, and the training recipes already log these counters in their own runs, so the dedicated debug entry point is redundant. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/ok to test 2977f23
…ats API

Raise patch coverage above the 80% Codecov target by adding three new test classes for previously-uncovered helpers in ``te_attention.py``:

- TestDetectCausalMask (13 tests): float/bfloat16/bool causal masks, sliding-window detection, and the None-return branches for unsupported dtype, wrong ndim, non-square mask, unmasked upper-right corner, masked diagonal, padded batch, unexpected first-row width, and sliding-window size mismatch.
- TestProjOutFeatures (5 tests): standard Linear, weight-only custom module, wrapped linear (Gemma4ClippableLinear layout), None input, and module without any of the expected attributes.
- TestStatsAPI (2 tests): reset/get round-trip and copy semantics so callers cannot mutate the internal counters.

All 20 new tests pass locally alongside the existing 22.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/ok to test 1e5fffa
/claude review
def _make_te_sdpa_with_mock(self, num_heads=8, num_kv_heads=4):
    te_module = mock.MagicMock()
    # TE returns [B, S, H, D]
    te_module.return_value = torch.zeros(2, 10, num_heads, 64)
The fallback_scale_mismatch path in te_sdpa is untested — it has non-trivial logic (one-shot warning via _SCALE_MISMATCH_WARNED, stats tracking, and the tolerance check abs(scale - softmax_scale) > 1e-6). Worth adding a test that passes a mismatched scale kwarg and asserts:
- `original_sdpa` is called (not `te_module`)
- `get_te_attention_stats()["fallback_scale_mismatch"]` increments
Note: _SCALE_MISMATCH_WARNED is a module-level global that isn't reset by reset_te_attention_stats(), so any test verifying the warning log would need to manually reset it to avoid ordering dependence.
----------------
- Only causal (``is_causal=True``) and no-mask attention are supported.
  Non-trivial ``attn_mask`` tensors fall back to native SDPA with a debug log.
- Sliding-window attention is not yet handled (uses unbounded left window).
@khazic curious what's the impact of this on Gemma4 runs with TE attn? Since Gemma4 uses sliding-window attention.
Address Claude's review comment (PR NVIDIA-NeMo#2011): the ``fallback_scale_mismatch`` path was uncovered. Add ``TestScaleMismatchFallback`` with five cases:

- Mismatch triggers fallback to ``original_sdpa`` and increments the ``fallback_scale_mismatch`` counter; the TE module is not called.
- Matching scale uses the TE module (no fallback, counter unchanged).
- Float drift below the 1e-6 tolerance is accepted.
- ``scale=None`` (PyTorch default) is tolerated.
- Warning emission is one-shot (``_SCALE_MISMATCH_WARNED`` flag set on first mismatch; subsequent mismatches still fall back but do not re-warn).

Tests reset the flag explicitly so ordering is independent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
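The check those tests exercise is essentially this (sketch; tolerance per the review thread, `scale=None` handling per the commit message):

```python
def _scale_mismatch(runtime_scale, baked_in_softmax_scale, tol=1e-6):
    # scale=None means "use PyTorch's default", which the TE path tolerates;
    # an explicit scale drifting beyond the tolerance forces native-SDPA fallback.
    return runtime_scale is not None and abs(runtime_scale - baked_in_softmax_scale) > tol
```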
/ok to test 8959f98
The original "Limitations (v1)" block was written before the mask-detection and sliding-window fixes landed. Update the header to describe the current behavior: - mask handling: causal / no_mask (attn_mask=None) and runtime detection of HF's canonical causal / causal+sliding masks via ``_detect_causal_mask``, with fallback to native SDPA for non-canonical patterns. - sliding window: read from ``module.sliding_window`` at injection time and converted to TE's ``(window_size[0], 0)`` convention, applied on both the mask-detected and ``attn_mask=None`` paths. - remaining limitation: modules that import SDPA as a local symbol rather than reading ``torch.nn.functional.scaled_dot_product_attention`` at call time won't pick up the runtime patch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/ok to test fe93303
@@ -0,0 +1,582 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
@khazic what do you think about moving this file inside https://github.com/NVIDIA-NeMo/Automodel/tree/main/nemo_automodel/components/attention?
Summary
- `nemo_automodel/_transformers/te_attention.py`: a generic TE attention injection layer that monkey-patches `F.scaled_dot_product_attention` inside any HF `self_attn` module with TE's `DotProductAttention` (FlashAttention-3 / FP8 capable), without requiring model-specific rewrites.
- Supports standard projection layouts (`q_proj`/`k_proj`/`v_proj`) with GQA, sliding-window attention, and custom linear wrappers (e.g. `Gemma4ClippableLinear`).
- `capabilities.py`: `_uses_te_attention()` now also detects HF models injected via the `_te_attention_injected` flag, so `supports_cp`, `supports_sequence_packing`, etc. resolve correctly.
- `infrastructure.py`: `apply_model_infrastructure()` accepts `inject_te_attention=True` to trigger injection after weight loading and before sharding.
- `auto_model.py`: `attn_implementation="te"` is now a valid option for HF models; internally loads with `"sdpa"` then injects TE post-init.

Perf
H100 e2e training
Llama-3.1-8B-Instruct 8k seqlen mock
Qwen2.5-7B-Instruct 8k seqlen mock
Gemma4 4B 4k seqlen text only mock
FA2 does NOT support Gemma4 due to 512 head_dim.
TE-FA3 dispatches to TE-unfused for configurations FA3 doesn't support.
*Need to install FA3 for the TE-FA3 path with `NVTE_FLASH_ATTN=1 NVTE_FUSED_ATTN=0`.

Convergence
Gemma4-E4B-it MedPix VLM — 50-step convergence on shipped gemma4_4b_te.yaml
HF-FA2 crashed (head_dim=512 > FA cap) — expected, same as text-only.
Loss convergence (SDPA vs TE, with TE using cuDNN-Fused on all 70 self_attn modules)