refactor(pt): full refactor of HybridMuon optimizer #5275

OutisLi wants to merge 2 commits into deepmodeling:master
Conversation
- Implement block-wise momentum-gradient alignment with EMA smoothing and soft scaling to [0.1, 1.0] on Muon updates (`magma_muon` option)
- Fix AdamW weight decay to use `adam_lr` instead of base `lr`
- Wire `magma_muon` through training config and argcheck
- Clean up redundant optimizer tests
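The alignment/damping bullet above can be sketched as follows. This is a hypothetical helper: the function name, the EMA placement, and the linear map from cosine alignment to the [0.1, 1.0] range are illustrative assumptions, not the PR's exact code.

```python
import torch


def magma_scale(momentum, grad, ema_state, beta=0.9, lo=0.1, hi=1.0):
    """Sketch of Magma-lite damping: EMA-smoothed momentum/gradient
    alignment, soft-scaled into [lo, hi]. Names are illustrative."""
    # Cosine alignment between the momentum buffer and the raw gradient.
    cos = torch.nn.functional.cosine_similarity(
        momentum.flatten(), grad.flatten(), dim=0
    )
    # Exponential moving average of the alignment score.
    ema = beta * ema_state + (1.0 - beta) * cos
    # Map [-1, 1] linearly onto [lo, hi], then clamp for safety.
    scale = lo + (hi - lo) * 0.5 * (ema + 1.0)
    return scale.clamp(lo, hi), ema
```

The returned `ema` would be stored in the per-parameter optimizer state and fed back on the next step.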
📝 Walkthrough

This PR refactors the HybridMuonOptimizer's parameter routing system, replacing the previous `muon_2d_only`/`min_2d_dim` binary flags with a name-aware, mode-driven scheme (`muon_mode`: "2d", "flat", or "slice"). It introduces Magma-lite damping for Muon updates, adds batched orthogonalization for slice-mode processing, and updates configuration and integration points accordingly.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant HybridMuonOptimizer
    participant RoutingLogic as Routing Logic<br/>(Name-based)
    participant ShapeAnalysis as Shape Analysis<br/>(Mode-aware)
    participant MagmaScaler as Magma Scaler<br/>(if enabled)
    participant OptimizerStep as Optimizer Step<br/>(Muon/Adam/AdamW)
    User->>HybridMuonOptimizer: step(closure)
    HybridMuonOptimizer->>RoutingLogic: Evaluate param name<br/>(bias, adam_, adamw_?)
    RoutingLogic-->>HybridMuonOptimizer: Route: Adam/AdamW/Muon
    alt Muon Route
        HybridMuonOptimizer->>ShapeAnalysis: Get matrix view shape<br/>(muon_mode: 2d/flat/slice)
        ShapeAnalysis-->>HybridMuonOptimizer: Effective shape & view dims
        alt magma_muon enabled
            HybridMuonOptimizer->>MagmaScaler: Compute damping scales<br/>(per-param EMA scoring)
            MagmaScaler-->>HybridMuonOptimizer: Damping scale factors
        end
        HybridMuonOptimizer->>OptimizerStep: Apply Newton-Schulz orth<br/>(batched if slice mode)
        OptimizerStep->>OptimizerStep: Apply Muon update<br/>with optional damping
    else Adam/AdamW Route
        HybridMuonOptimizer->>OptimizerStep: Apply Adam/AdamW update<br/>(momentum or EMA)
    end
    OptimizerStep-->>HybridMuonOptimizer: Updated state
    HybridMuonOptimizer-->>User: Loss (from closure)
```
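The name-then-shape routing shown in the diagram can be sketched as a single function. The exact rules and precedence here are assumptions reconstructed from the walkthrough (bias/`adam_`/`adamw_` name checks, then mode-aware shape analysis), not the PR's actual code.

```python
def route_param(name: str, shape: tuple, muon_mode: str = "2d") -> str:
    """Illustrative routing sketch: name-based rules first, then
    effective-shape rules driven by muon_mode."""
    # Effective shape: singleton dimensions removed.
    eff = tuple(d for d in shape if d != 1)
    # Name-based routing takes precedence over shape.
    if name.endswith("bias") or "adamw_" in name:
        return "adamw"
    if "adam_" in name or len(eff) < 2:
        return "adam"
    # Shape-based routing for the remaining parameters.
    if muon_mode == "2d":
        # Only true matrices go to Muon; higher-rank tensors fall back
        # to the decoupled-decay (AdamW-style) path.
        return "muon" if len(eff) == 2 else "adamw"
    # "flat"/"slice" handle higher-rank tensors via reshape or batching.
    return "muon"
```

A quick sanity check: a `(1, 8, 8)` weight has effective rank 2 and still routes to Muon in `"2d"` mode.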
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)
3758-3768: Validate `muon_mode` values at argcheck time.
`muon_mode` is a free-form `str` here, so typos pass schema normalization and fail later during optimizer construction. Consider constraining accepted values to `{"2d", "flat", "slice"}` in this layer for earlier, clearer errors.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/utils/argcheck.py` around lines 3758 - 3768, The muon_mode argument in the argcheck schema (the "muon_mode" param definition) is currently an unconstrained str which lets typos slip through; change the schema to restrict allowed values to the set {"2d", "flat", "slice"} (e.g. use an enum/choices validator or an explicit check) so validation fails early with a clear message referencing muon_mode when an invalid value is provided.
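A minimal sketch of the suggested constraint. This is a standalone helper for illustration only; the real fix would use the argcheck schema's own choices/enum mechanism rather than a free function.

```python
ALLOWED_MUON_MODES = {"2d", "flat", "slice"}


def check_muon_mode(value: str) -> str:
    """Fail early with a clear message when muon_mode is misspelled."""
    if value not in ALLOWED_MUON_MODES:
        raise ValueError(
            f"invalid muon_mode {value!r}; "
            f"expected one of {sorted(ALLOWED_MUON_MODES)}"
        )
    return value
```

With this in place, a typo like `"2D"` is rejected at config-validation time instead of surfacing later during optimizer construction.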
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 343-345: Define specific exception classes (e.g.,
InvalidTensorShapeError(ValueError) and InvalidMuonModeError(ValueError)) near
the top of the module with the full explanatory messages as their default
docstring/message, then replace the three inline multi-line ValueError raises
with simple raises of those classes: replace the shape check in
batched_newton_schulz (the current raise ValueError(... "Batched Newton-Schulz
expects a 3D tensor...")) with raise InvalidTensorShapeError, and replace both
muon_mode validation raises (the f-string multi-line and the single-line check)
with raise InvalidMuonModeError; run ruff check/format before committing.
In `@source/tests/pt/test_hybrid_muon.py`:
- Around line 358-361: The assertion comparing optimizer state uses exact float
equality which is fragile on CUDA; update the torch.allclose call for
model1.adam_scale vs model2.adam_scale to use a small nonzero tolerance (e.g.
atol=1e-6 and/or rtol=1e-6) instead of atol=0.0, rtol=0.0 so the test checks
near-equality while remaining stable; locate the comparison around
model1.adam_scale in the test_hybrid_muon.py and replace the zero tolerances
with a tight positive tolerance.
ℹ️ Review info
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- `deepmd/pt/optimizer/hybrid_muon.py`
- `deepmd/pt/train/training.py`
- `deepmd/utils/argcheck.py`
- `source/tests/pt/test_hybrid_muon.py`
```python
raise ValueError(
    "Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)."
)
```
🧩 Analysis chain

🏁 Scripts executed (repository: deepmodeling/deepmd-kit):

```shell
# First, check if the file exists and get line count
wc -l deepmd/pt/optimizer/hybrid_muon.py

# Extract the specific lines mentioned in the review comment
sed -n '343,345p; 477,479p; 604p' deepmd/pt/optimizer/hybrid_muon.py | cat -n

# Run ruff check to detect TRY003 violations in the file
ruff check deepmd/pt/optimizer/hybrid_muon.py --select TRY003

# Get more context around the mentioned lines
echo "=== Lines 343-345 ===" && sed -n '341,347p' deepmd/pt/optimizer/hybrid_muon.py | cat -n
echo "=== Lines 477-479 ===" && sed -n '475,481p' deepmd/pt/optimizer/hybrid_muon.py | cat -n
echo "=== Lines 604 ===" && sed -n '602,606p' deepmd/pt/optimizer/hybrid_muon.py | cat -n
```
Resolve TRY003 linting warnings for ValueError messages.
Three instances violate Ruff's TRY003 rule (avoid specifying long messages outside the exception class):
- Lines 343-345: Multi-line error message for tensor shape validation
- Lines 477-479: Multi-line error message with f-string for muon_mode validation
- Line 604: Single-line error message for muon_mode validation
Consolidate messages into exception class definitions or use dedicated exception types to keep the file lint-clean. Per coding guidelines, run `ruff check .` and `ruff format .` before committing.
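As a hedged sketch of the dedicated-exception approach: the class names come from the review suggestion, while the constructor shape and message wording are illustrative.

```python
class InvalidTensorShapeError(ValueError):
    """Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)."""

    def __init__(self) -> None:
        # Message lives on the class, satisfying Ruff's TRY003.
        super().__init__(self.__doc__)


class InvalidMuonModeError(ValueError):
    """muon_mode must be one of '2d', 'flat', or 'slice'."""

    def __init__(self, mode: str = "") -> None:
        msg = f"{self.__doc__} (got {mode!r})" if mode else self.__doc__
        super().__init__(msg)
```

Call sites then shrink to `raise InvalidTensorShapeError()` or `raise InvalidMuonModeError(mode)`, with no long literal at the raise site.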
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 343-345: Avoid specifying long messages outside the exception class
(TRY003)
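The guarded function here is the slice-mode batched orthogonalization. Below is a sketch of a Muon-style quintic Newton-Schulz iteration over a batch of matrices; the coefficients are those from the public Muon reference implementation and may differ from what this PR actually uses.

```python
import torch


def batched_newton_schulz(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize each matrix in a (B, m, n) batch."""
    if g.dim() != 3:
        raise ValueError(
            "Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)."
        )
    # Quintic iteration coefficients from the Muon reference implementation.
    a, b, c = 3.4445, -4.7750, 2.0315
    # Normalize so the iteration starts inside its convergence region.
    x = g / (g.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    tall = x.shape[-2] > x.shape[-1]
    if tall:
        # Work on the transposed view so the Gram matrix is the smaller one.
        x = x.transpose(-2, -1)
    for _ in range(steps):
        s = x @ x.transpose(-2, -1)
        x = a * x + (b * s + c * (s @ s)) @ x
    if tall:
        x = x.transpose(-2, -1)
    return x
```

Note the iteration drives singular values toward 1 only approximately; Muon tolerates that looseness in exchange for a few cheap matmuls per step.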
```python
self.assertFalse(torch.allclose(model1.weight, model2.weight))
self.assertTrue(
    torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0)
)
```
Avoid exact float equality for Adam-path invariance assertions.
Using atol=0.0, rtol=0.0 can be flaky on CUDA due to tiny nondeterministic differences. A tight tolerance keeps the intent while improving stability.
💡 Suggested test tweak

```diff
- self.assertTrue(
-     torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0)
- )
+ self.assertTrue(
+     torch.allclose(model1.adam_scale, model2.adam_scale, atol=1e-7, rtol=1e-6)
+ )
```
Pull request overview
This PR refactors the PyTorch HybridMuonOptimizer to use name-based routing, adds a new muon_mode routing scheme (including per-slice Muon for higher-rank tensors), and introduces optional “Magma-lite” damping applied only on the Muon update path. It also updates training/config plumbing and expands tests to cover the new routing and damping behavior.
Changes:
- Replace `muon_2d_only`/`min_2d_dim` routing with `muon_mode` (`2d`/`flat`/`slice`) and parameter-name-based routing rules.
- Add `magma_muon` option implementing per-block momentum/gradient alignment scoring and damping on Muon updates.
- Update training arg schema + trainer optimizer construction; expand unit tests for slice-mode routing and Magma damping.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `deepmd/pt/optimizer/hybrid_muon.py` | Implements `muon_mode` routing, name-based Adam/AdamW routing, batched NS for slice mode, and Magma-lite damping. |
| `deepmd/pt/train/training.py` | Wires new optimizer args (`muon_mode`, `magma_muon`) and passes named parameters for name-based routing. |
| `deepmd/utils/argcheck.py` | Updates the training config schema/docs for HybridMuon to use `muon_mode` and adds `magma_muon`. |
| `source/tests/pt/test_hybrid_muon.py` | Removes outdated tests and adds new coverage for slice routing, 2d routing behavior, and Magma damping state/range. |
```diff
-    muon_2d_only=bool(self.opt_param["muon_2d_only"]),
-    min_2d_dim=int(self.opt_param["min_2d_dim"]),
+    muon_mode=str(self.opt_param["muon_mode"]),
+    named_parameters=tuple(self.wrapper.named_parameters()),
```

`named_parameters=tuple(self.wrapper.named_parameters())` eagerly materializes all `(name, param)` pairs, which can be expensive in memory/time for large models. Since `HybridMuonOptimizer` only needs to iterate once to build an id-to-name map, pass `self.wrapper.named_parameters()` directly (or another lazy iterable) instead of converting to a tuple.

Suggested change:

```diff
-    named_parameters=tuple(self.wrapper.named_parameters()),
+    named_parameters=self.wrapper.named_parameters(),
```
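For illustration, one pass over a lazy iterable is indeed enough to build the id-to-name map, which is why the tuple conversion is unnecessary. The helper name below is hypothetical.

```python
def build_name_map(named_parameters):
    """Consume a (name, param) iterable once, producing an id -> name map.

    Works equally well with a lazy generator (e.g. Module.named_parameters())
    or a materialized sequence; nothing here requires random access.
    """
    return {id(p): name for name, p in named_parameters}
```

Because the comprehension iterates exactly once, passing the raw generator avoids holding a second reference list of every parameter alongside the optimizer's own state.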
```
- Parameters are routed by effective shape (singleton dimensions removed).
- ``muon_mode="2d"``:
  - effective rank 2 parameters use Muon.
  - effective rank >2 parameters use Adam.
```

In the HybridMuonOptimizer docstring, ``muon_mode="2d"`` currently says effective-rank >2 parameters use plain Adam, but `_build_param_routing()` routes these to the decoupled-decay AdamW-style path (`adam_decay`). Please update the docstring to match the actual behavior (Adam + decoupled weight decay for non-matrix shapes in 2d mode).

Suggested change:

```diff
-  - effective rank >2 parameters use Adam.
+  - effective rank >2 parameters use Adam with decoupled weight decay
+    (AdamW-style) fallback.
```
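The effective-shape notion the docstring relies on can be made concrete with a one-liner (helper name illustrative):

```python
def effective_rank(shape) -> int:
    """Rank of a tensor after dropping singleton dimensions, per the
    routing docstring: (1, 128, 64) counts as rank 2, not rank 3."""
    return len([d for d in shape if d != 1])
```

So in `"2d"` mode a `(1, 128, 64)` weight is treated as a matrix and goes to Muon, while a `(16, 8, 3, 3)` convolution kernel (effective rank 4) falls back to the Adam path with decoupled decay.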
```
.. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz.
   https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin)
.. [5] Magma: Momentum-Aligned Gradient Masking for Stable Optimizer Updates.
   arXiv:2602.15322, 2025.
```

The reference `arXiv:2602.15322, 2025` is internally inconsistent: arXiv IDs starting with 2602 correspond to February 2026. Please correct the year in the citation (or adjust the identifier) so the reference is accurate.

Suggested change:

```diff
-   arXiv:2602.15322, 2025.
+   arXiv:2602.15322, 2026.
```
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #5275      +/-   ##
==========================================
+ Coverage   81.94%   81.96%   +0.01%
==========================================
  Files         750      750
  Lines       75456    75536      +80
  Branches     3648     3648
==========================================
+ Hits        61831    61911      +80
+ Misses      12457    12456       -1
- Partials     1168     1169       +1
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit
Release Notes
New Features

- Multiple Muon routing modes (`slice`, `2d`, `flat`) controlled via the `muon_mode` parameter.
- New `magma_muon` option to enable Magma-lite damping for improved optimizer stability.

Documentation
Tests