feat: inject TransformerEngine DotProductAttention into HF models#2011

Merged
HuiyingLi merged 30 commits into NVIDIA-NeMo:main from khazic:feat/te-attention
Apr 24, 2026

Conversation

@khazic
Contributor

@khazic khazic commented Apr 23, 2026

Summary

  • Adds nemo_automodel/_transformers/te_attention.py: a generic TE attention injection layer that monkey-patches F.scaled_dot_product_attention inside any HF self_attn module with TE's DotProductAttention (FlashAttention-3 / FP8 capable), without requiring model-specific rewrites.
  • Supports standard Llama-style layout (q_proj/k_proj/v_proj) with GQA, sliding-window attention, and custom linear wrappers (e.g. Gemma4ClippableLinear).
  • capabilities.py: _uses_te_attention() now also detects HF models injected via the _te_attention_injected flag, so supports_cp, supports_sequence_packing, etc. resolve correctly.
  • infrastructure.py: apply_model_infrastructure() accepts inject_te_attention=True to trigger injection after weight loading and before sharding.
  • auto_model.py: attn_implementation="te" is now a valid option for HF models; internally loads with "sdpa", then injects TE post-init (see the usage sketch after this list).
  • 20+ unit tests covering parameter inference, SDPA replacement, GQA unrepeat, fallback, and forward-patch restoration.
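A minimal usage sketch of the new option; the auto-model class name and checkpoint are illustrative, since the PR only specifies that from_pretrained / from_config accept attn_implementation="te":

```python
# Hedged sketch: the class name and checkpoint are assumptions; the PR states
# only that attn_implementation="te" is accepted by from_pretrained / from_config.
from nemo_automodel import NeMoAutoModelForCausalLM

model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    attn_implementation="te",  # HF init runs with "sdpa"; TE is injected post-init
)
```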

Perf

H100 e2e training

Llama-3.1-8B-Instruct 8k seqlen mock

┌──────────┬──────────┬─────────┬─────────┬───────────┐
│ backend  │ mean tps │ ± stdev │ vs SDPA │ mean loss │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ SDPA     │   69,531 │     231 │       — │    4.6300 │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ TE-Fused │   75,214 │     198 │   +8.2% │    4.6341 │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ HF-FA2   │   70,352 │     125 │   +1.2% │    4.6301 │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ TE-FA3   │   76,920 │     129 │  +10.6% │    4.6280 │
└──────────┴──────────┴─────────┴─────────┴───────────┘

Qwen2.5-7B-Instruct 8k seqlen mock

┌──────────┬──────────┬─────────┬─────────┬───────────┐
│ backend  │ mean tps │ ± stdev │ vs SDPA │ mean loss │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ SDPA     │   75,349 │     251 │       — │    4.6310 │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ TE-Fused │   81,543 │     146 │   +8.2% │    4.6311 │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ HF-FA2   │   76,329 │     256 │   +1.3% │    4.6310 │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ TE-FA3   │   82,809 │     136 │   +9.9% │    4.6311 │
└──────────┴──────────┴─────────┴─────────┴───────────┘

Gemma4 4B 4k seqlen text-only mock
FA2 does NOT support Gemma4 due to its 512 head_dim.
TE-FA3 dispatches to the TE unfused kernel where FA3 is unsupported.

┌──────────┬──────────┬─────────┬─────────┬───────────┐
│ backend  │ mean tps │ ± stdev │ vs SDPA │ mean loss │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ SDPA     │    2,722 │      63 │       — │    4.9508 │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ TE-Fused │    2,793 │      79 │   +2.6% │    4.9304 │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ HF-FA2   │    CRASH │       — │       — │         — │
├──────────┼──────────┼─────────┼─────────┼───────────┤
│ TE-FA3   │    2,812 │      82 │   +3.3% │    4.9155 │ 
└──────────┴──────────┴─────────┴─────────┴───────────┘

*The TE-FA3 path requires FlashAttention-3 to be installed and is enabled with NVTE_FLASH_ATTN=1 NVTE_FUSED_ATTN=0.

Convergence

Gemma4-E4B-it MedPix VLM — 50-step convergence on shipped gemma4_4b_te.yaml

HF-FA2 crashed (head_dim=512 > FA cap) — expected, same as text-only.

Loss convergence (SDPA vs TE, with TE using cuDNN-Fused on all 70 self_attn modules)

┌──────┬───────────┬─────────┬──────────────┬────────────────┬──────────────┐
│ step │ SDPA loss │ TE loss │ Δ(TE − SDPA) │ SDPA grad_norm │ TE grad_norm │
├──────┼───────────┼─────────┼──────────────┼────────────────┼──────────────┤
│ 0    │    3.3541 │  3.3568 │      +0.0027 │          41.75 │        42.00 │
├──────┼───────────┼─────────┼──────────────┼────────────────┼──────────────┤
│ 10   │    2.4664 │  2.4608 │      −0.0056 │          12.44 │        12.44 │
├──────┼───────────┼─────────┼──────────────┼────────────────┼──────────────┤
│ 20   │    2.4986 │  2.5105 │      +0.0119 │          12.00 │        12.25 │
├──────┼───────────┼─────────┼──────────────┼────────────────┼──────────────┤
│ 30   │    2.3723 │  2.3693 │      −0.0030 │          14.12 │        13.62 │
├──────┼───────────┼─────────┼──────────────┼────────────────┼──────────────┤
│ 40   │    2.2567 │  2.2556 │      −0.0011 │           6.69 │         7.03 │
├──────┼───────────┼─────────┼──────────────┼────────────────┼──────────────┤
│ 49   │    2.1817 │  2.1815 │      −0.0002 │           8.88 │         8.38 │
└──────┴───────────┴─────────┴──────────────┴────────────────┴──────────────┘

@copy-pr-bot

copy-pr-bot Bot commented Apr 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

khazic added 5 commits April 23, 2026 11:41
Adds `nemo_automodel/_transformers/te_attention.py` which post-init patches
each `self_attn` module in a HuggingFace model to route the core attention
kernel through TE's `DotProductAttention` (FlashAttention-3 / FP8-capable),
without requiring model-specific rewrites.

Usage: pass `attn_implementation="te"` to `from_pretrained` / `from_config`.
The framework converts this to `"sdpa"` for HF model init, then calls
`inject_te_attention()` after PEFT application and before FSDP sharding.

Key design decisions:
- Intercepts `F.scaled_dot_product_attention` inside each attention forward
  via a per-module forward wrapper (restored via try/finally).
- Stores `attn_module = DotProductAttention(...)` on each `self_attn` so
  `_uses_te_attention()` detection is O(1) via `_te_attention_injected` flag.
- Handles GQA by undoing `repeat_kv` before handing off to TE's native
  `num_gqa_groups` path.
- Falls back to native SDPA for non-trivial `attn_mask` tensors (v1 scope).
- Skips injection when packed sequences force `flash_attention_2`.

Also updates `capabilities.py:_uses_te_attention` to detect injected TE
(in addition to the existing `BackendConfig.attn='te'` check for custom
models), and adds `inject_te_attention: bool` parameter to
`apply_model_infrastructure`.

20 unit tests covering param inference, SDPA replacement logic, injection
lifecycle, and forward-patch restoration on exception.

Signed-off-by: khazic <khazzz1c@gmail.com>
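A minimal sketch of the per-module forward wrapper described above (helper names are illustrative, not the PR's actual identifiers):

```python
# Sketch only: swap F.scaled_dot_product_attention for a TE-backed callable
# for the duration of one self_attn forward, restoring it in a finally block.
import functools

import torch.nn.functional as F


def _wrap_self_attn_forward(self_attn, te_sdpa):
    original_forward = self_attn.forward

    @functools.wraps(original_forward)
    def patched_forward(*args, **kwargs):
        original_sdpa = F.scaled_dot_product_attention
        F.scaled_dot_product_attention = te_sdpa
        try:
            return original_forward(*args, **kwargs)
        finally:
            # Restore even if the attention forward raises.
            F.scaled_dot_product_attention = original_sdpa

    self_attn.forward = patched_forward
    self_attn._te_attention_injected = True  # flag read by _uses_te_attention()
```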
…sliding window

- _infer_attn_params now falls back to q_proj.out_features // head_dim
  when the module has no num_heads attribute (fixes Gemma4TextAttention)
- Reads sliding_window attribute and converts to TE window_size=(W-1, 0)
- Threads window_size through _create_te_dot_product_attention and _make_te_sdpa
- Updates tests: standard_layout now checks window_size key; replaces
  test_missing_num_heads_returns_none with inference-from-proj test;
  adds sliding window coverage

Signed-off-by: khazic <khazzz1c@gmail.com>
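A sketch of the inference fallback this commit describes (the helper's real signature in te_attention.py may differ; names follow the commit text):

```python
def infer_heads_and_window(self_attn, head_dim: int):
    """Infer num_heads and the TE window_size for one self_attn module (sketch)."""
    # Prefer the module's own num_heads; otherwise derive it from the width
    # of the q projection (fixes modules like Gemma4TextAttention).
    num_heads = getattr(self_attn, "num_heads", None)
    if num_heads is None:
        num_heads = self_attn.q_proj.out_features // head_dim
    # TE expects a (left, right) window; a causal sliding window of size W
    # becomes (W - 1, 0), and None means an unbounded left window.
    sliding_window = getattr(self_attn, "sliding_window", None)
    window_size = (sliding_window - 1, 0) if sliding_window is not None else None
    return num_heads, window_size
```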
…thout out_features

Gemma4ClippableLinear does not expose out_features, so inferring num_heads
from q_proj.out_features failed silently. weight.shape[0] gives the same
value and works on meta device.

Signed-off-by: khazic <khazzz1c@gmail.com>
…d linears

Gemma4ClippableLinear nests the actual nn.Linear under a .linear child
module. _proj_out_features now tries out_features → weight.shape[0] →
.linear recursion, so all three layouts resolve correctly.

Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: khazic <khazzz1c@gmail.com>
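The resolution chain from the two commits above, as a sketch (the real _proj_out_features may differ in detail):

```python
def _proj_out_features(proj):
    """Best-effort output width of a projection module (sketch)."""
    if proj is None:
        return None
    # 1. Plain nn.Linear exposes out_features directly.
    out_features = getattr(proj, "out_features", None)
    if out_features is not None:
        return out_features
    # 2. Weight-only custom modules: rows of the weight give the output width,
    #    and this also works for parameters materialized on the meta device.
    weight = getattr(proj, "weight", None)
    if weight is not None:
        return weight.shape[0]
    # 3. Wrappers like Gemma4ClippableLinear nest the real nn.Linear under a
    #    .linear child; recurse one level down.
    inner = getattr(proj, "linear", None)
    if inner is not None:
        return _proj_out_features(inner)
    return None
```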
@khazic khazic force-pushed the feat/te-attention branch from 3e01aac to 4020024 on April 23, 2026 03:41
@HuiyingLi
Contributor

/ok to test 4020024

@khazic
Contributor Author

khazic commented Apr 23, 2026

Update: also tested on Gemma4-31B-it (single GPU, bf16) — all self_attn modules injected, forward pass OK with logits shape [1, 6, 262144].

…rdcoding head_dim**-0.5

Models like Gemma4 store a pre-computed scale in module.scaling that differs
from the standard 1/sqrt(head_dim). Fall back to head_dim**-0.5 only when
the attribute is absent.

Signed-off-by: khazic <khazzz1c@gmail.com>
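The scale resolution this commit describes, sketched as a small helper (name assumed):

```python
def resolve_softmax_scale(self_attn, head_dim: int) -> float:
    # Prefer the model's own pre-computed scale (e.g. Gemma4's `scaling`);
    # only fall back to the standard 1/sqrt(head_dim) when it is absent.
    scale = getattr(self_attn, "scaling", None)
    return float(scale) if scale is not None else head_dim ** -0.5
```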
@HuiyingLi
Contributor

/ok to test b024cc7

khazic added 11 commits April 23, 2026 17:56
Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: khazic <khazzz1c@gmail.com>
- Downgrade per-call dispatch log from INFO to DEBUG (was firing once per
  self_attn per forward and noticeably slowed multi-layer VLM runs).
- Add module-level hit/fallback counters (te_hits, fallback_mask,
  fallback_scale_mismatch) plus get_te_attention_stats() /
  reset_te_attention_stats() helpers, and an optional periodic summary
  gated by AUTOMODEL_TE_STATS_EVERY (default 500, 0 disables).
- te_sdpa no longer silently ignores the runtime scale argument: if a
  caller passes an explicit scale that disagrees with the softmax_scale
  baked into the TE module at construction, fall back to native SDPA and
  emit a one-shot warning. Previously TE would quietly use the baked-in
  value, which could produce wrong numerics on models that pass scale=
  per-call.
- Short-circuit the mask fallback before the transpose/contiguous copies
  so the fallback path no longer pays QKV copy overhead.
- Add tools/debug_te_attention.py: a single-GPU smoke test that runs a
  forward pass through a small HF model with TE injected and prints the
  dispatch counters, so fallback ratio can be validated on A100 without
  needing FA3 hardware.

Signed-off-by: khazic <khazzz1c@gmail.com>
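How the counters could be read around a training run, assuming the module path stated in the PR description (the exact stats keys are taken from the commit text and the dispatch output below):

```python
import os

from nemo_automodel._transformers.te_attention import (
    get_te_attention_stats,
    reset_te_attention_stats,
)

os.environ.setdefault("AUTOMODEL_TE_STATS_EVERY", "500")  # 0 disables the periodic summary

reset_te_attention_stats()
# ... run forward passes through a model with TE attention injected ...
stats = get_te_attention_stats()
print(
    f"te_hits={stats['te_hits']} "
    f"fallback_mask={stats['fallback_mask']} "
    f"fallback_scale_mismatch={stats['fallback_scale_mismatch']}"
)
```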
…vocab_size

AutoModelForCausalLM fails on Gemma4 (VLM); switch to AutoModel.
Also fall back to text_config.vocab_size so the script works for both
pure-LM and VLM model configs without a hardcoded class name.

Signed-off-by: khazic <khazzz1c@gmail.com>
…rnel

HF models (Gemma4, Llama, etc.) always produce a 4D float attention mask
via create_causal_mask / create_sliding_window_causal_mask, even when the
batch has no padding. This caused 83% of self_attn calls to fall back to
native SDPA, completely negating any FA3/TE speedup.

Add _detect_causal_mask() which identifies the two structurally trivial
cases with O(S) tensor ops (two row reductions):
- Pure causal lower-triangular mask → attn_mask_type='causal', drop mask
- Sliding-window causal mask → attn_mask_type='causal', window_size=(W-1,0)

All batch items are checked (not just item 0) so padded batches still fall
back safely to native SDPA. This converts the majority of fallback_mask
calls to te_hits for standard no-padding training runs.

Signed-off-by: khazic <khazzz1c@gmail.com>
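A structural sketch of the check described above; the real _detect_causal_mask also returns the TE attn_mask_type / window_size information, so treat this as an illustration of the row-reduction idea only:

```python
import torch


def detect_causal_structure(mask: torch.Tensor, sliding_window=None):
    """Classify a [B, H, S, S] mask as 'causal', 'sliding', or None (sketch)."""
    if mask.ndim != 4 or mask.shape[-1] != mask.shape[-2]:
        return None
    # bool masks: True = attend; float additive masks: values > -1e4 = attend.
    visible = mask if mask.dtype == torch.bool else mask > -1e4
    # The upper-right corner must be masked for any causal pattern.
    if visible[..., 0, -1].any():
        return None
    # Row reduction: visible positions per query row, checked for all batch
    # items so padded batches fall back safely.
    counts = visible.sum(dim=-1)
    causal_counts = torch.arange(1, mask.shape[-1] + 1, device=mask.device)
    if (counts == causal_counts).all():
        return "causal"
    if sliding_window is not None:
        if (counts == causal_counts.clamp(max=sliding_window)).all():
            return "sliding"
    return None
```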
Log dtype, shape, corner values, and per-row visible counts at DEBUG
level so we can see exactly why mask detection returns None.

Signed-off-by: khazic <khazzz1c@gmail.com>
…formers

Newer transformers versions (e.g. 4.48+) emit bool attention masks
(True=attend, False=masked) instead of float additive masks (-inf/0).
_detect_causal_mask previously rejected all bool masks, causing 100%
fallback to native SDPA on these transformers versions.

Support both formats:
- bool mask: True = visible, count with .sum(); corner check with not corner.item()
- float mask: > -1e4 = visible (existing logic, unchanged)

Signed-off-by: khazic <khazzz1c@gmail.com>
…loading

Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: khazic <khazzz1c@gmail.com>
@khazic
Contributor Author

khazic commented Apr 23, 2026

A100 dispatch validation results

Validated TE kernel dispatch on A100 (correctness only — FA3 speedup requires H100, pending).

Root causes fixed in this update

  • Bug: HF always emits an explicit causal mask, so the attn_mask is not None fallback fired on every layer. Impact: the TE kernel never ran; 83% of calls fell back to native SDPA.
  • Bug: mask detection only handled float additive masks, while newer transformers emits torch.bool. Impact: 100% fallback on transformers ≥ 4.48.
  • Bug: per-call logger.info in the attention hot path. Impact: throughput regression on multi-layer VLMs (observed -58% tps drop).
  • Bug: the runtime scale argument was silently ignored. Impact: silent numeric divergence when a caller passes an explicit softmax scale.

Dispatch counter results (tools/debug_te_attention.py)

Gemma4-E4B (70 self_attn modules injected, A100 single GPU):

te_sdpa calls : 168
te_hits       : 168  (100.0%)
fallback_mask :   0  (  0.0%)
fallback_scale_mismatch : 0

Gemma4-31B (87 self_attn modules injected, 8×A100 device_map=auto):

te_sdpa calls : 240
te_hits       : 240  (100.0%)
fallback_mask :   0  (  0.0%)
fallback_scale_mismatch : 0

Both models tested with and without explicit attention_mask. FA3 speedup numbers on H100 to follow.

@HuiyingLi
Contributor

/ok to test 6034494

…te test for softmax_scale

torch.cuda.device(cpu_device) raises ValueError in newer PyTorch because
_get_device_index no longer uses optional=True.  Use contextlib.nullcontext()
when the input tensor is not on CUDA.

_infer_attn_params now returns softmax_scale; update test_standard_layout
to include it in the expected dict.

Signed-off-by: khazic <khazzz1c@gmail.com>
@HuiyingLi
Contributor

/ok to test 7eabac1

Comment thread: tools/debug_te_attention.py (Outdated)
Contributor

we might want to remove this

Contributor Author

okk

Drop the standalone injection smoke-test script; the dispatch counters
are now exposed via ``get_te_attention_stats()`` for callers that need
them, and the training recipes already log these counters in their own
runs, so the dedicated debug entry point is redundant.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor

/ok to test 2977f23

…ats API

Raise patch coverage above the 80% Codecov target by adding three new
test classes for previously-uncovered helpers on ``te_attention.py``:

- TestDetectCausalMask (13 tests): float/bfloat16/bool causal masks,
  sliding-window detection, and the None-return branches for unsupported
  dtype, wrong ndim, non-square mask, unmasked upper-right corner,
  masked diagonal, padded batch, unexpected first-row width, and
  sliding-window size mismatch.
- TestProjOutFeatures (5 tests): standard Linear, weight-only custom
  module, wrapped linear (Gemma4ClippableLinear layout), None input,
  and module without any of the expected attributes.
- TestStatsAPI (2 tests): reset/get round-trip and copy semantics so
  callers cannot mutate the internal counters.

All 20 new tests pass locally alongside the existing 22.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor

/ok to test 1e5fffa

@HuiyingLi
Contributor

/claude review

def _make_te_sdpa_with_mock(self, num_heads=8, num_kv_heads=4):
    te_module = mock.MagicMock()
    # TE returns [B, S, H, D]
    te_module.return_value = torch.zeros(2, 10, num_heads, 64)
Contributor

The fallback_scale_mismatch path in te_sdpa is untested — it has non-trivial logic (one-shot warning via _SCALE_MISMATCH_WARNED, stats tracking, and the tolerance check abs(scale - softmax_scale) > 1e-6). Worth adding a test that passes a mismatched scale kwarg and asserts:

  1. original_sdpa is called (not te_module)
  2. get_te_attention_stats()["fallback_scale_mismatch"] increments

Note: _SCALE_MISMATCH_WARNED is a module-level global that isn't reset by reset_te_attention_stats(), so any test verifying the warning log would need to manually reset it to avoid ordering dependence.

----------------
- Only causal (``is_causal=True``) and no-mask attention are supported.
  Non-trivial ``attn_mask`` tensors fall back to native SDPA with a debug log.
- Sliding-window attention is not yet handled (uses unbounded left window).
Contributor

@athitten athitten Apr 23, 2026

@khazic Curious, what's the impact of this on Gemma4 runs with TE attn, since Gemma4 uses sliding-window attention?

Address Claude's review comment (PR NVIDIA-NeMo#2011): the ``fallback_scale_mismatch``
path was uncovered. Add ``TestScaleMismatchFallback`` with five cases:

- Mismatch triggers fallback to ``original_sdpa`` and increments the
  ``fallback_scale_mismatch`` counter; TE module is not called.
- Matching scale uses the TE module (no fallback, counter unchanged).
- Float drift below the 1e-6 tolerance is accepted.
- ``scale=None`` (PyTorch default) is tolerated.
- Warning emission is one-shot (``_SCALE_MISMATCH_WARNED`` flag set on
  first mismatch, subsequent mismatches still fall back but do not
  re-warn). Tests reset the flag explicitly so ordering is independent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor

/ok to test 8959f98

The original "Limitations (v1)" block was written before the
mask-detection and sliding-window fixes landed. Update the header to
describe the current behavior:

- mask handling: causal / no_mask (attn_mask=None) and runtime detection
  of HF's canonical causal / causal+sliding masks via
  ``_detect_causal_mask``, with fallback to native SDPA for non-canonical
  patterns.
- sliding window: read from ``module.sliding_window`` at injection time
  and converted to TE's ``(window_size[0], 0)`` convention, applied on
  both the mask-detected and ``attn_mask=None`` paths.
- remaining limitation: modules that import SDPA as a local symbol rather
  than reading ``torch.nn.functional.scaled_dot_product_attention`` at
  call time won't pick up the runtime patch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
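A small illustration of the remaining limitation, showing why a locally imported SDPA symbol bypasses the runtime patch (self-contained demo, not code from the PR):

```python
import torch
import torch.nn.functional as F
# A module that binds SDPA to a local name at import time keeps that binding.
from torch.nn.functional import scaled_dot_product_attention as local_sdpa

original = F.scaled_dot_product_attention
F.scaled_dot_product_attention = lambda *args, **kwargs: "patched"
try:
    q = k = v = torch.zeros(1, 1, 2, 4)
    # Attribute lookup at call time sees the patch ...
    print(F.scaled_dot_product_attention(q, k, v))  # -> "patched"
    # ... but the local binding still calls the original kernel.
    print(local_sdpa(q, k, v).shape)                # -> torch.Size([1, 1, 2, 4])
finally:
    F.scaled_dot_product_attention = original
```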
@HuiyingLi
Contributor

/ok to test fe93303
