Skip to content

Commit 95002ad

Browse files
WCJ-BERTclaude
andcommitted
fix(AlgEngine): math-SDPA switch for H-series (Hopper) training crash
On H-series (Hopper, sm_90: H100/H200/H800/H20/GH200) GPUs, the attention-heavy DiffusionPlanningHead routes its nn.MultiheadAttention through torch 2.0.1's fused flash / mem-efficient SDPA kernels, which fault on the first training iteration with "CUDA error: an illegal instruction was encountered". The fault is reported asynchronously by the NCCL watchdog, so it masquerades as a distributed-comm/hardware error even though the real culprit is the attention kernel. Repros across machines (same conda env: torch 2.0.1, bundled NCCL 2.14.3); 1-2 GPU and the vadv2 head are unaffected. Add maybe_disable_efficient_sdp(), gated by env NAVFORMER_DISABLE_EFFICIENT_SDP=1, to force the math SDPA backend as a workaround. No-op by default, so other configs (vadv2 / hydramdp) are untouched unless the var is set. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent f254779 commit 95002ad

1 file changed

Lines changed: 17 additions & 0 deletions

File tree

projects/AlgEngine/scripts/train.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,22 @@
2424
warnings.filterwarnings("ignore")
2525

2626

27+
def maybe_disable_efficient_sdp():
28+
if os.environ.get("NAVFORMER_DISABLE_EFFICIENT_SDP", "0") != "1":
29+
return
30+
31+
cuda_backends = getattr(torch.backends, "cuda", None)
32+
if cuda_backends is None:
33+
return
34+
35+
if hasattr(cuda_backends, "enable_flash_sdp"):
36+
cuda_backends.enable_flash_sdp(False)
37+
if hasattr(cuda_backends, "enable_mem_efficient_sdp"):
38+
cuda_backends.enable_mem_efficient_sdp(False)
39+
if hasattr(cuda_backends, "enable_math_sdp"):
40+
cuda_backends.enable_math_sdp(True)
41+
print("NAVFORMER_DISABLE_EFFICIENT_SDP=1: using math SDPA backend")
42+
2743

2844
def parse_args():
2945
parser = argparse.ArgumentParser(description='Train a detector')
@@ -95,6 +111,7 @@ def parse_args():
95111

96112

97113
def main():
114+
maybe_disable_efficient_sdp()
98115
args = parse_args()
99116

100117
cfg = Config.fromfile(args.config)

0 commit comments

Comments
 (0)