fix: try fix dpa4 compile#5483
Conversation
for more information, see https://pre-commit.ci
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
| val = getattr(fitting, aname, None) | ||
| if val is not None and torch.is_tensor(val): | ||
| names.append(_FIT_ATTR_PREFIX + aname) | ||
| except AttributeError: |
| names.append(_FIT_ATTR_PREFIX + aname) | ||
| except AttributeError: | ||
| pass | ||
| except AttributeError: |
There was a problem hiding this comment.
Pull request overview
This PR attempts to improve/repair the PyTorch-compiled execution path for the SeZM/DPA4 model, primarily by reducing recompiles/OOM in multi-task setups and addressing symbolic-shape tracing issues in make_fx.
Changes:
- Add module-level compile sharing and promote selected per-task buffers (e.g.,
out_bias,bias_atom_e,case_embd) as FX inputs to enable compiled-graph reuse across shared-parameter tasks. - Add additional symbolic-shape anti-aliasing logic for trace inputs and temporarily disable
ShapeEnvduck sizing during tracing. - Change edge-list construction to append a single masked dummy edge (instead of two) and adjust related documentation/behavior.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| aparam: torch.Tensor | None = None, | ||
| charge_spin: torch.Tensor | None = None, | ||
| *, | ||
| do_atomic_virial: bool = False, | ||
| charge_spin: torch.Tensor | None = None, | ||
| ) -> torch.nn.Module: |
| # === Step 3. Compact edges + append one masked dummy === | ||
| # NOTE: Always append exactly one masked dummy edge. | ||
| # ``torch.nonzero(edge_mask_actual)`` produces a data-dependent | ||
| # number of valid edges, which can be zero on sparse or | ||
| # single-type systems. make_fx cannot trace an |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5483 +/- ##
==========================================
- Coverage 81.34% 81.34% -0.01%
==========================================
Files 868 868
Lines 96373 96583 +210
Branches 4233 4233
==========================================
+ Hits 78399 78569 +170
- Misses 16675 16715 +40
Partials 1299 1299 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
No description provided.