
refactor(pt): full refactor of HybridMuon optimizer #5275

Draft
OutisLi wants to merge 2 commits into deepmodeling:master from OutisLi:pr/muon

Conversation


@OutisLi OutisLi commented Mar 1, 2026

  1. Refactor to name-based parameter routing
  2. Add a slice mode for the HybridMuon optimizer
  3. Add Magma-lite damping for the Muon path

Summary by CodeRabbit

Release Notes

  • New Features

    • Updated HybridMuon optimizer with new routing modes (slice, 2d, flat) controlled via muon_mode parameter.
    • Added magma_muon option to enable Magma-lite damping for improved optimizer stability.
    • Enhanced parameter routing with name-aware detection for bias and Adam variants.
  • Documentation

    • Updated optimizer configuration documentation to reflect new routing behavior and options.
  • Tests

    • Added comprehensive test coverage for new routing modes and damping functionality.

OutisLi added 2 commits March 1, 2026 10:21
- Implement block-wise momentum-gradient alignment with EMA smoothing
  and soft scaling [0.1, 1.0] on Muon updates (magma_muon option)
- Fix AdamW weight decay to use adam_lr instead of base lr
- Wire magma_muon through training config and argcheck
- Clean up redundant optimizer tests
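For readers unfamiliar with Magma-style damping, the first commit's description (block-wise momentum-gradient alignment, EMA smoothing, soft scaling into [0.1, 1.0]) roughly corresponds to logic like the following sketch. The function name, EMA coefficient, and state layout here are illustrative assumptions, not the PR's actual code:

```python
import torch

MAGMA_MIN_SCALE = 0.1  # assumed lower bound of the soft scaling range


def magma_scale(momentum: torch.Tensor, grad: torch.Tensor,
                ema_state: dict, beta: float = 0.95) -> float:
    """Hypothetical Magma-lite damping sketch: score how well the momentum
    buffer aligns with the current gradient, smooth the score with an EMA,
    and map it into [MAGMA_MIN_SCALE, 1.0]."""
    # Cosine alignment between the flattened momentum and gradient blocks.
    m, g = momentum.flatten(), grad.flatten()
    cos = torch.dot(m, g) / (m.norm() * g.norm() + 1e-12)
    # EMA smoothing of the alignment score across steps.
    prev = ema_state.get("align", cos.item())
    ema = beta * prev + (1.0 - beta) * cos.item()
    ema_state["align"] = ema
    # Soft-map alignment in [-1, 1] onto a damping scale in [0.1, 1.0].
    return MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE) * 0.5 * (ema + 1.0)
```

Under this reading, a Muon update whose momentum points away from the fresh gradient is attenuated toward the 0.1 floor rather than suppressed entirely.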
Copilot AI review requested due to automatic review settings March 1, 2026 02:25
@github-actions github-actions bot added the Python label Mar 1, 2026
@dosubot dosubot bot added the enhancement label Mar 1, 2026

coderabbitai bot commented Mar 1, 2026

📝 Walkthrough

This PR refactors the HybridMuonOptimizer's parameter routing system, replacing the previous muon_2d_only/min_2d_dim binary flags with a name-aware, mode-driven scheme (muon_mode: "2d", "flat", or "slice"). It introduces Magma-lite damping for Muon updates, adds batched orthogonalization for slice-mode processing, and updates configuration and integration points accordingly.
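As a rough illustration of the name-aware routing described above, the decision could look like the sketch below. The helper name echoes the get_adam_route mentioned in this PR, but the specific rules shown are a simplified guess, not the actual implementation:

```python
def get_adam_route(name: str, shape: tuple) -> str:
    """Hypothetical routing rule: biases, effectively 0D/1D parameters,
    and name-tagged Adam/AdamW variants skip Muon; all other matrix-like
    parameters are orthogonalized by the Muon path."""
    if name.endswith("bias") or len([d for d in shape if d > 1]) <= 1:
        return "adam"    # nothing to orthogonalize for vector-like shapes
    if "adamw_" in name:
        return "adamw"   # name-tagged decoupled-weight-decay variant
    if "adam_" in name:
        return "adam"    # name-tagged plain Adam variant
    return "muon"
```

The mode-driven part (2d / flat / slice) would then only decide how a Muon-routed tensor is viewed as matrices, not whether it goes to Muon at all.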

Changes

Cohort / File(s) Summary
Hybrid Muon Core Implementation
deepmd/pt/optimizer/hybrid_muon.py
Completely reworked routing policy from 1D/2D dichotomy to name-aware and mode-driven routing. Replaced muon_2d_only/min_2d_dim with muon_mode (2d/flat/slice) and named_parameters. Added Magma-lite damping logic with new _compute_magma_scale and _compute_magma_scales_for_bucket methods. Introduced _batched_newton_schulz_orth for slice-mode batched orthogonalization. Added helper functions get_adam_route, get_effective_shape, and get_matrix_view_shape for routing decisions. Updated init signature and step type hints.
Configuration & Argument Validation
deepmd/utils/argcheck.py
Updated HybridMuon argument schema: removed muon_2d_only and min_2d_dim parameters, added muon_mode string (default "slice") and magma_muon boolean (default False). Expanded documentation to explain name-based Adam routing and muon_mode routing behavior (matrix vs. slice vs. flat).
Training Integration
deepmd/pt/train/training.py
Updated optimizer configuration to propagate new muon_mode and magma_muon parameters. Modified HybridMuonOptimizer instantiation to accept named_parameters and removed old parameter passing for muon_2d_only and min_2d_dim.
Test Suite
source/tests/pt/test_hybrid_muon.py
Added MAGMA_MIN_SCALE import. Removed obsolete tests (test_shape_and_dtype, test_step, test_muon_adam_fallback_small_2d, test_flash_muon_*). Added comprehensive new test coverage for slice/2D/flat mode routing, Magma-lite damping validation, 3D weight routing behavior, and state dict compatibility.

Sequence Diagram

sequenceDiagram
    actor User
    participant HybridMuonOptimizer
    participant RoutingLogic as Routing Logic<br/>(Name-based)
    participant ShapeAnalysis as Shape Analysis<br/>(Mode-aware)
    participant MagmaScaler as Magma Scaler<br/>(if enabled)
    participant OptimizerStep as Optimizer Step<br/>(Muon/Adam/AdamW)

    User->>HybridMuonOptimizer: step(closure)
    HybridMuonOptimizer->>RoutingLogic: Evaluate param name<br/>(bias, adam_, adamw_?)
    RoutingLogic-->>HybridMuonOptimizer: Route: Adam/AdamW/Muon
    
    alt Muon Route
        HybridMuonOptimizer->>ShapeAnalysis: Get matrix view shape<br/>(muon_mode: 2d/flat/slice)
        ShapeAnalysis-->>HybridMuonOptimizer: Effective shape & view dims
        
        alt magma_muon enabled
            HybridMuonOptimizer->>MagmaScaler: Compute damping scales<br/>(per-param EMA scoring)
            MagmaScaler-->>HybridMuonOptimizer: Damping scale factors
        end
        
        HybridMuonOptimizer->>OptimizerStep: Apply Newton-Schulz orth<br/>(batched if slice mode)
        OptimizerStep->>OptimizerStep: Apply Muon update<br/>with optional damping
    else Adam/AdamW Route
        HybridMuonOptimizer->>OptimizerStep: Apply Adam/AdamW update<br/>(momentum or EMA)
    end
    
    OptimizerStep-->>HybridMuonOptimizer: Updated state
    HybridMuonOptimizer-->>User: Loss (from closure)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

enhancement, new feature, Python

Suggested reviewers

  • wanghan-iapcm
  • njzjz
  • iProzd
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 59.46%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): the title accurately describes the main change, a comprehensive refactor of the HybridMuon optimizer introducing name-based routing, slice mode, and Magma-lite damping, all central to this PR.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)

3758-3768: Validate muon_mode values at argcheck time.

muon_mode is free-form str here, so typos pass schema normalization and fail later during optimizer construction. Consider constraining accepted values to {"2d", "flat", "slice"} in this layer for earlier, clearer errors.
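One way to realize this suggestion (a plain-Python sketch; the real fix would live in the dargs Argument definition for muon_mode in argcheck.py):

```python
VALID_MUON_MODES = {"2d", "flat", "slice"}


def check_muon_mode(muon_mode: str) -> str:
    """Fail fast at config-validation time with a clear message instead of
    deferring the error to optimizer construction."""
    if muon_mode not in VALID_MUON_MODES:
        raise ValueError(
            f"muon_mode must be one of {sorted(VALID_MUON_MODES)}, "
            f"got {muon_mode!r}"
        )
    return muon_mode
```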


ℹ️ Review info

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f0a966b and 52b027f.

📒 Files selected for processing (4)
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_hybrid_muon.py

Comment on lines +343 to +345
raise ValueError(
    "Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)."
)

⚠️ Potential issue | 🟡 Minor



Resolve TRY003 linting warnings for ValueError messages.

Three instances violate Ruff's TRY003 rule (avoid specifying long messages outside the exception class):

  • Lines 343-345: Multi-line error message for tensor shape validation
  • Lines 477-479: Multi-line error message with f-string for muon_mode validation
  • Line 604: Single-line error message for muon_mode validation

Consolidate messages into exception class definitions or use dedicated exception types to keep the file lint-clean. Per coding guidelines, run ruff check . and ruff format . before committing.
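A sketch of the dedicated-exception approach (class names follow the suggestion above and are not from the PR itself):

```python
class InvalidTensorShapeError(ValueError):
    """Raised when batched Newton-Schulz receives a non-3D tensor."""

    def __init__(self) -> None:
        super().__init__(
            "Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)."
        )


class InvalidMuonModeError(ValueError):
    """Raised when muon_mode is not one of '2d', 'flat', or 'slice'."""

    def __init__(self, mode: str) -> None:
        super().__init__(f"Unsupported muon_mode: {mode!r}")


def batched_newton_schulz(x):
    # The raise site shrinks to a bare raise, which satisfies TRY003.
    if x.ndim != 3:
        raise InvalidTensorShapeError()
    ...
```

The long messages then live once in the exception classes, and every raise site stays a single short line.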

🧰 Tools
🪛 Ruff (0.15.2)

[warning] 343-345: Avoid specifying long messages outside the exception class

(TRY003)


Comment on lines +358 to +361
self.assertFalse(torch.allclose(model1.weight, model2.weight))
self.assertTrue(
    torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0)
)

⚠️ Potential issue | 🟡 Minor

Avoid exact float equality for Adam-path invariance assertions.

Using atol=0.0, rtol=0.0 can be flaky on CUDA due to tiny nondeterministic differences. A tight tolerance preserves the intent while improving stability.

💡 Suggested test tweak
-        self.assertTrue(
-            torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0)
-        )
+        self.assertTrue(
+            torch.allclose(model1.adam_scale, model2.adam_scale, atol=1e-7, rtol=1e-6)
+        )


Copilot AI left a comment


Pull request overview

This PR refactors the PyTorch HybridMuonOptimizer to use name-based routing, adds a new muon_mode routing scheme (including per-slice Muon for higher-rank tensors), and introduces optional “Magma-lite” damping applied only on the Muon update path. It also updates training/config plumbing and expands tests to cover the new routing and damping behavior.

Changes:

  • Replace muon_2d_only / min_2d_dim routing with muon_mode (2d / flat / slice) and parameter-name-based routing rules.
  • Add magma_muon option implementing per-block momentum/gradient alignment scoring and damping on Muon updates.
  • Update training arg schema + trainer optimizer construction; expand unit tests for slice-mode routing and Magma damping.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

  • deepmd/pt/optimizer/hybrid_muon.py: implements muon_mode routing, name-based Adam/AdamW routing, batched Newton-Schulz for slice mode, and Magma-lite damping.
  • deepmd/pt/train/training.py: wires new optimizer args (muon_mode, magma_muon) and passes named parameters for name-based routing.
  • deepmd/utils/argcheck.py: updates the training config schema/docs for HybridMuon to use muon_mode and adds magma_muon.
  • source/tests/pt/test_hybrid_muon.py: removes outdated tests and adds new coverage for slice routing, 2d routing behavior, and Magma damping state/range.


-                muon_2d_only=bool(self.opt_param["muon_2d_only"]),
-                min_2d_dim=int(self.opt_param["min_2d_dim"]),
+                muon_mode=str(self.opt_param["muon_mode"]),
+                named_parameters=tuple(self.wrapper.named_parameters()),

Copilot AI Mar 1, 2026


named_parameters=tuple(self.wrapper.named_parameters()) eagerly materializes all (name, param) pairs, which can be expensive in memory/time for large models. Since HybridMuonOptimizer only needs to iterate once to build an id->name map, pass self.wrapper.named_parameters() directly (or another lazy iterable) instead of converting to a tuple.

Suggested change
named_parameters=tuple(self.wrapper.named_parameters()),
named_parameters=self.wrapper.named_parameters(),

Copilot uses AI. Check for mistakes.
- Parameters are routed by effective shape (singleton dimensions removed).
- ``muon_mode="2d"``:
  - effective rank 2 parameters use Muon.
  - effective rank >2 parameters use Adam.
Copilot AI Mar 1, 2026


In the HybridMuonOptimizer docstring, muon_mode="2d" currently says effective-rank >2 parameters use plain Adam, but _build_param_routing() routes these to the decoupled-decay AdamW-style path (adam_decay). Please update the docstring to match the actual behavior (Adam + decoupled weight decay for non-matrix shapes in 2d mode).

Suggested change
- effective rank >2 parameters use Adam.
- effective rank >2 parameters use Adam with decoupled weight decay
(AdamW-style) fallback.

.. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz.
https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin)
.. [5] Magma: Momentum-Aligned Gradient Masking for Stable Optimizer Updates.
arXiv:2602.15322, 2025.

Copilot AI Mar 1, 2026


The reference arXiv:2602.15322, 2025 is internally inconsistent: arXiv IDs starting with 2602 correspond to Feb 2026. Please correct the year in the citation (or adjust the identifier) so the reference is accurate.

Suggested change
arXiv:2602.15322, 2025.
arXiv:2602.15322, 2026.


codecov bot commented Mar 1, 2026

Codecov Report

❌ Patch coverage is 79.41176% with 42 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.96%. Comparing base (f0a966b) to head (52b027f).

Files with missing lines | Patch % | Lines
deepmd/pt/optimizer/hybrid_muon.py | 79.41% | 42 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5275      +/-   ##
==========================================
+ Coverage   81.94%   81.96%   +0.01%     
==========================================
  Files         750      750              
  Lines       75456    75536      +80     
  Branches     3648     3648              
==========================================
+ Hits        61831    61911      +80     
+ Misses      12457    12456       -1     
- Partials     1168     1169       +1     

☔ View full report in Codecov by Sentry.

@OutisLi OutisLi marked this pull request as draft March 1, 2026 03:52
