…corder

Reduces the slurm --time for the gpt_oss_120b SFT CI recipe from 1h to 20m (the failing run on pipeline 49171565 used 18m17s of walltime before SIGABRT). Adds env_vars so that a future hang dumps the collective stack trace instead of the current "Stack trace of the failed collective not found, potentially because FlightRecorder is disabled" message:

  TORCH_NCCL_TRACE_BUFFER_SIZE: 20000
  TORCH_NCCL_DUMP_ON_TIMEOUT: 1

Exercises the per-recipe env_vars path introduced in #1999.
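A sketch of what the per-recipe entry might look like in test_recipes.yml. The key names and nesting here are assumptions based on the #1999 schema description, not verified against the actual file:

```yaml
# Hypothetical test_recipes.yml fragment -- field names are illustrative.
gpt_oss_120b_sft:
  time: "00:20:00"   # was 1h; the failing run needed ~18m17s before SIGABRT
  env_vars:
    TORCH_NCCL_TRACE_BUFFER_SIZE: 20000  # enable the FlightRecorder ring buffer
    TORCH_NCCL_DUMP_ON_TIMEOUT: 1        # dump collective traces on watchdog timeout
```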
Drops the other 38 recipes from test_recipes.yml for this branch so the nemo-ci pipeline only spawns the single failing test we're debugging. This commit exists solely for iterating on the NCCL FlightRecorder dump; drop it before merging anything back to main.
_clip_grad_norm_impl delegated per-sharding-group norms to torch.nn.utils.get_total_norm, which for DTensor grads stacks the per-param scalar norms into a 1-D DTensor whose local length equals the number of local param tensors in the group. Under expert parallelism, ranks hold different numbers of expert param tensors in a group, so torch.linalg.vector_norm's redistribute (Partial -> Replicate) fires an allreduce with mismatched numel across ranks and hangs until the NCCL watchdog kills the job. FlightRecorder backtrace from jobs/303663415:

  _reduce_value (tensor/_ops/_math_ops.py:134)
  redistribute_local_tensor (tensor/_redistribute.py:906)
  _dispatch_get_local_results_slow_path (tensor/_dispatch.py:261)
  _get_total_norm (nn/utils/clip_grad.py:106)
  _clip_grad_norm_impl (utils.py:110)

Replaced with a scalar-first reduction: compute sum(|g_local|^p) as a scalar on each rank, then issue one numel=1 allreduce per Shard mesh dim. The math is identical for Shard placements (sum_over_shards(|g_local|^p) == |g_full|^p). Partial placements still fall back to per-grad full_tensor(), since they cannot be reduced from local powers alone; model param grads in FSDP2+EP don't hit this path. Fixes gpt_oss_120b SFT on 4x8 H100 (ep_size=32).
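A toy numerical sketch of the scalar-first reduction (plain Python, no torch/DTensor; the allreduce is simulated as a sum over ranks). It assumes p=2 and Shard placements, and shows why uneven per-rank shard counts are harmless once each rank first collapses its shards to one scalar:

```python
import math

# Per-param gradients of the full (unsharded) model.
full_grads = [[3.0, 4.0], [1.0, 2.0, 2.0], [6.0]]

# Uneven sharding across 2 ranks, as under expert parallelism:
# rank 0 holds two param tensors, rank 1 holds one.
rank_shards = [full_grads[0:2], full_grads[2:3]]

def local_pth_power(shards, p=2.0):
    # sum(|g_local|^p) over all local elements -> one scalar per rank,
    # so the subsequent allreduce always has numel=1 on every rank.
    return sum(abs(x) ** p for grads in shards for x in grads)

# One numel=1 "allreduce" per Shard mesh dim (here: a plain sum).
total = sum(local_pth_power(shards) for shards in rank_shards)
global_norm = total ** 0.5

# Reference: the norm computed over the full, unsharded gradients.
reference = math.sqrt(sum(x * x for grads in full_grads for x in grads))
assert math.isclose(global_norm, reference)
```

This only holds for Shard placements, where local elements partition the full tensor; for Partial placements the local values are unreduced contributions, which is why that case still falls back to full_tensor().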
Diagnostic log to confirm whether expert stacked tensors silently fail to load from HF MXFP4 checkpoints under GPTOSSStateDictAdapter's to_hf/from_hf blocks+scales dance. If the experts are at random init, their per-param RMS will be ~0.13 (Kaiming-ish init scale), while non-experts will be ~0.01 (pretrained magnitude). Revert once the weight-loading path is fixed.
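A minimal sketch of the RMS check behind the diagnostic (pure Python; the parameter names, values, and threshold below are illustrative, not the actual adapter code or real checkpoint magnitudes):

```python
import math

def param_rms(values):
    # Root-mean-square of a flat list of parameter values.
    return math.sqrt(sum(v * v for v in values) / len(values))

def flag_suspect_params(named_params, threshold=0.05):
    # Names whose RMS looks like random init (~0.13) rather than
    # loaded pretrained weights (~0.01). Threshold is a guess that
    # splits those two regimes.
    return [name for name, vals in named_params.items()
            if param_rms(vals) > threshold]

# Fake params standing in for a real state dict.
fake_params = {
    "experts.0.w1": [0.13, -0.12, 0.14],    # random-init-like magnitude
    "attn.q_proj": [0.01, -0.009, 0.011],   # pretrained-like magnitude
}
print(flag_suspect_params(fake_params))  # -> ['experts.0.w1']
```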
What does this PR do?
Pipeline: https://gitlab-master.nvidia.com/dl/JoC/nemo-ci/-/jobs/303680220