
feat(pt_expt): full model and refact the module output names of dpmodel backend#5243

Closed
wanghan-iapcm wants to merge 77 commits into deepmodeling:master from wanghan-iapcm:feat-full-model

Conversation

@wanghan-iapcm
Collaborator

@wanghan-iapcm wanghan-iapcm commented Feb 14, 2026

Summary by CodeRabbit

  • New Features

    • Exposed model introspection (descriptor and output/bias accessors) and a PyTorch experimental energy model with traceable lower-level export and translated output mapping.
  • Improvements

    • Better device propagation for GPU/accelerator allocations, backend-agnostic input/output casting, and removal of in-place mutations for safer computation.
  • Refactor

    • Streamlined PyTorch module wrapping to decorator-based classes for cleaner runtime integration.
  • Tests

    • Added extensive autodiff and cross-backend tests for energy, force, and virial (including PT-Expt).

@dosubot dosubot bot added the new feature label Feb 14, 2026
Comment on lines +36 to +44
def forward(
self,
coord: torch.Tensor,
atype: torch.Tensor,
box: torch.Tensor | None = None,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
do_atomic_virial: bool = False,
) -> dict[str, torch.Tensor]:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning

This method requires at least 3 positional arguments, whereas overridden
Identity.forward
requires 2.
This method requires at least 3 positional arguments, whereas overridden test_torch_module_respects_explicit_forward.MockModule.forward requires 2.
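A minimal sketch of how such a mismatch can be resolved (the classes here are hypothetical stand-ins, not the actual deepmd modules): keep the override's required positionals identical to the base class and make every new input an optional keyword argument, forwarding **kwargs.

```python
# Hypothetical base/child pair illustrating a call-compatible override.
class Base:
    def forward(self, x, **kwargs):
        return x

class Child(Base):
    # Same required positionals as Base; everything new is optional,
    # so any call valid on Base is also valid on Child.
    def forward(self, x, extra=None, **kwargs):
        y = super().forward(x, **kwargs)
        return y if extra is None else y + extra

print(Child().forward(1))           # 1
print(Child().forward(1, extra=2))  # 3
```

Whether this is desirable for `forward` depends on how the base class is used; the warning can also be suppressed if the base signature is intentionally narrower.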
return
return super().__setattr__(name, value)

def call(self, x: torch.Tensor) -> torch.Tensor:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning

This method does not accept arbitrary keyword arguments, which overridden
NativeOP.call
does.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires 2 positional arguments, whereas overridden
NativeOP.call
may be called with 1.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires 2 positional arguments, whereas overridden
NativeOP.call
may be called with arbitrarily many.
This call
correctly calls the base method, but does not match the signature of the overriding method.
)
# Compare the common keys
common_keys = set(dp_ret.keys()) & set(pt_ret.keys())
self.assertTrue(len(common_keys) > 0)

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a > b) cannot provide an informative message. Using assertGreater(a, b) instead will give more informative messages.
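For comparison, a standalone sketch of the recommended style (`Demo` is a toy test case, not part of the suite): on failure, `assertGreater` reports both operands (e.g. "0 not greater than 0") instead of `assertTrue`'s bare "False is not true".

```python
import unittest

class Demo(unittest.TestCase):
    def test_common_keys(self):
        common_keys = {"energy"}  # stand-in for set(dp_ret) & set(pt_ret)
        # Reports both operands on failure, unlike assertTrue(len(...) > 0).
        self.assertGreater(len(common_keys), 0)

result = unittest.TextTestRunner(verbosity=0).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(Demo)
)
print(result.wasSuccessful())  # True
```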
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements the full model support for the experimental PyTorch backend (pt_expt), including hooks for descriptor and fitting layer evaluation, and autograd-based derivative calculations. The changes also improve backend-agnostic array operations in the core dpmodel and ensure compatibility with torch.fx tracing. I have identified a few issues regarding potential runtime errors in the new evaluation hooks and unintended side effects in the output definition translation logic.

@coderabbitai
Contributor

coderabbitai bot commented Feb 14, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds PT-Expt PyTorch-exportable model layers and autodiff-based output transformations, introduces descriptor and bias accessors, refactors input/output casting across backends, enables middle-output capture in fitting, updates device-aware tensor creation, replaces in-place network ops, and adds extensive pt_expt tests.

Changes

Cohort / File(s) Summary
Fitting output path
deepmd/dpmodel/fitting/general_fitting.py
Always return a results dict keyed by the model var name; introduce local results dict and ensure middle-output is stored under result mapping for mixed and non-mixed flows.
Model accessors & energy mapping (dpmodel)
deepmd/dpmodel/model/dp_model.py, deepmd/dpmodel/model/ener_model.py
Added get_descriptor() accessor in DPModelCommon; added translated_output_def() in dpmodel energy model to map internal names to user-facing outputs.
Type-cast refactor (dpmodel core)
deepmd/dpmodel/model/make_model.py
Renamed public casting helpers to internal _input_type_cast/_output_type_cast, generalized dtype handling/return types, added out-bias accessors (get_out_bias, set_out_bias, change_out_bias) and updated call wiring.
Device-aware output transforms (dpmodel)
deepmd/dpmodel/model/transform_output.py
Propagate device when allocating arrays (zeros/virial/hessian) to ensure tensors are created on the mapping/device.
In-place arithmetic removal (dpmodel utils)
deepmd/dpmodel/utils/network.py
Replaced in-place operators with explicit arithmetic assignments in network forward logic.
PT-Expt module decorator migration
deepmd/pt_expt/atomic_model/dp_atomic_model.py, deepmd/pt_expt/fitting/ener_fitting.py, deepmd/pt_expt/fitting/invar_fitting.py
Replace manual torch.nn.Module plumbing with @torch_module decorator; remove custom init/call/setattr and dpmodel_setattr usages.
PT-Expt model infrastructure & EnergyModel
deepmd/pt_expt/model/__init__.py, deepmd/pt_expt/model/make_model.py, deepmd/pt_expt/model/ener_model.py
Add pt_expt make_model factory and a new EnergyModel class exposing forward / forward_lower and traceable lower-path exports; wire DPModelCommon integration for PyTorch.
PT-Expt transform & autodiff utilities
deepmd/pt_expt/model/transform_output.py
Add atomic_virial_corr, task_deriv_one, take_deriv, fit_output_to_model_output and helpers to convert fitting-network outputs into model outputs via torch.autograd, handling forces, virials, masking, and atomic decomposition.
PT-Expt network/tracing compatibility
deepmd/pt_expt/utils/network.py
Adjust parameter wrapping and add NativeLayer.call plus _torch_activation to avoid make_fx proxy-tracing issues; ensure parameters register appropriately.
PD/PT small wiring changes
deepmd/pd/model/model/make_model.py, deepmd/pt/model/model/make_model.py
Rename input/output cast helpers to underscored variants and update call sites to match dpmodel internal convention.
Atomic model runtime checks
deepmd/pt/model/atomic_model/dp_atomic_model.py
Add runtime validation to raise clear errors if eval_descriptor or eval_fitting_last_layer caches are empty when queried.
Tests & test infra additions
source/tests/... (multiple files)
Add PT-Expt test support and many new tests: PT-Expt energy model unit tests, autodiff finite-difference force/virial tests, integration with existing cross-backend Ener tests, and eval_pt_expt_model helper in common test utilities.
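The autodiff pathway described above for deepmd/pt_expt/model/transform_output.py can be sketched in isolation. This is an illustrative toy (toy_energy and energy_and_force are stand-ins, not the actual take_deriv/fit_output_to_model_output API): the force is the negative gradient of the reduced energy with respect to the coordinates, obtained via torch.autograd.grad.

```python
import torch

def toy_energy(coord):
    # Stand-in for the fitting-network energy: per-frame sum of squares.
    return (coord ** 2).sum(dim=(-2, -1))

def energy_and_force(coord):
    # Enable grad on the coordinates, reduce the per-frame energy,
    # then differentiate; force = -dE/dx.
    coord = coord.detach().requires_grad_(True)
    energy = toy_energy(coord)
    (grad,) = torch.autograd.grad(energy.sum(), coord, create_graph=False)
    return energy, -grad

coord = torch.ones(1, 4, 3)  # [nframes, natoms, 3]
e, f = energy_and_force(coord)
# For E = sum(x^2), the analytic force is -2x, so f is all -2 here.
```

The real implementation additionally handles virials, masking of padded atoms, and the atomic decomposition, but the grad call above is the core mechanism.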

Sequence Diagram(s)

sequenceDiagram
    participant Client as Client Code
    participant EM as EnergyModel (PT-Expt)
    participant Desc as Descriptor
    participant Fit as Fitting Net
    participant TO as TransformOutput
    participant Autograd as PyTorch Autograd

    Client->>EM: forward(coord, atype, box, ...)
    EM->>Desc: compute descriptors
    Desc-->>EM: descriptor tensor
    EM->>Fit: evaluate fitting network
    Fit-->>EM: fitting outputs (per-atom/reducible)
    EM->>TO: fit_output_to_model_output(fit_ret, coord_ext, ...)
    TO->>Autograd: enable grad on extended coords
    Autograd->>TO: compute ∇energy -> forces, virial, atom_virial
    TO-->>EM: model outputs (energy, atom_energy, force, virial, ...)
    EM-->>Client: return outputs
sequenceDiagram
    participant GP as GeneralFitting
    participant PT as Per-type Net(s)
    participant Acc as MiddleOutput Accumulator
    participant Res as Results Dict

    GP->>GP: eval_return_middle_output?
    alt mixed_types False
        GP->>Acc: init per-type accumulation
        loop per-type
            GP->>PT: evaluate net(xx)
            PT-->>GP: output (+ middle_output)
            GP->>Acc: accumulate middle_output
        end
        GP->>Res: store Acc as "middle_output"
    else mixed_types True
        GP->>PT: call_until_last(xx)
        PT-->>GP: middle_output
        GP->>Res: store middle_output
    end
    GP-->>Res: set var_name -> output tensor
    GP-->>Client: return Res

(Note: colored rectangles not used; flows kept minimal.)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • iProzd
  • njzjz
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 64.18%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (3 passed)
  • Description Check: ✅ Passed (check skipped: CodeRabbit's high-level summary is enabled).
  • Merge Conflict Detection: ✅ Passed (no merge conflicts detected when merging into master).
  • Title check: ✅ Passed (the title accurately describes the main changes: full-model support for the pt_expt backend and refactoring of the dpmodel backend's module output names).


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@source/tests/consistent/model/test_ener.py`:
- Around line 1184-1190: The variable nloc is assigned but never used—remove the
unused assignment or replace it with a meaningful use; specifically in the test
block where nframes, coords_2f, atype_2f, box_2f, natoms_data, and energy_data
are set, delete the declaration "nloc = 6" (or if intended, use nloc to drive
array shapes/validation) so there are no unused local variables in the test_ener
setup.
- Around line 124-126: The test contains a duplicate assignment to pd_class
(assigned to EnergyModelPD twice); remove the redundant second assignment so
pd_class is only set once and leave the other class assignments (pt_expt_class =
EnergyModelPTExpt and jax_class = EnergyModelJAX) unchanged.
- Around line 979-988: In test_change_out_bias the local variable nloc is
assigned but never used; remove the unused assignment (nloc = 6) from the
test_change_out_bias function to satisfy static analysis (or if the intended
intent was to use it, replace references to hardcoded 6 with nloc where
appropriate) so that only necessary variables (e.g. coords_2f, atype_2f, box_2f,
natoms_data, energy_data) remain.
🧹 Nitpick comments (6)
deepmd/dpmodel/fitting/general_fitting.py (1)

599-639: Minor: double forward pass when eval_return_middle_output is enabled.

When the hook is active, the network is evaluated twice per type (or once extra for mixed types) — once for the full output (line 609/635) and again via call_until_last (line 621/639). This duplicates computation of all layers except the last.

Since the dpmodel backend is primarily for reference/testing rather than production, this is acceptable. However, if performance becomes a concern, you could refactor the network to return both the final and penultimate outputs in a single pass.
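The suggested single-pass refactor could look roughly like this (forward_with_middle and the toy layers are hypothetical, not the actual network API): run all layers once, keeping the penultimate activation before applying the last layer.

```python
# Sketch: return both the final and penultimate outputs from one pass,
# instead of evaluating the full network and then call_until_last.
def forward_with_middle(layers, x):
    for layer in layers[:-1]:
        x = layer(x)
    middle = x            # penultimate activation, the "middle_output"
    out = layers[-1](x)   # final layer
    return out, middle

layers = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3]
out, middle = forward_with_middle(layers, 1)
print(out, middle)  # 1 4
```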

deepmd/pt_expt/model/transform_output.py (1)

109-110: zip() without strict=True (ruff B905).

Both split_vv1 and split_svv1 are produced from the same size split, so they're guaranteed to have equal length. Adding strict=True makes this invariant explicit and guards against future refactors that might change one without the other.

Suggested fix
-    for vvi, svvi in zip(split_vv1, split_svv1):
+    for vvi, svvi in zip(split_vv1, split_svv1, strict=True):
source/tests/pt_expt/model/test_ener_model.py (1)

207-250: Consider extending DP-consistency test to also verify force values.

The consistency test validates energy and atom_energy but doesn't check force. Since the dpmodel sets derivative outputs to None, this is understandable, but you could compute a numerical finite-difference force from the dpmodel to cross-validate the autograd-based force from pt_expt. This would strengthen confidence in the derivative pathway.

(Note: the autodiff test file mentioned in the summary may already cover this — feel free to disregard if so.)
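A minimal version of the suggested finite-difference cross-check (numeric_force and energy_fn are illustrative helpers, not part of the test suite): perturb each coordinate component by ±eps and take the central difference of the energy.

```python
import numpy as np

def numeric_force(energy_fn, coord, eps=1e-4):
    """Central finite-difference force -dE/dx for a coord array [natoms, 3].

    energy_fn is any callable returning a scalar energy; here it stands in
    for a dpmodel evaluation.
    """
    force = np.zeros_like(coord)
    for idx in np.ndindex(coord.shape):
        cp, cm = coord.copy(), coord.copy()
        cp[idx] += eps
        cm[idx] -= eps
        force[idx] = -(energy_fn(cp) - energy_fn(cm)) / (2 * eps)
    return force

coord = np.ones((2, 3))
f = numeric_force(lambda c: (c ** 2).sum(), coord)
# Analytic force for E = sum(x^2) is -2x, so f is all -2 here.
```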

deepmd/pt_expt/model/ener_model.py (1)

100-147: forward_lower bakes fparam, aparam, do_atomic_virial into the traced graph.

The closure captures these values as constants during make_fx tracing, so the returned module is specialized for the specific fparam/aparam/do_atomic_virial values passed at trace time. This is appropriate for export workflows but means the caller must re-trace for different parameter configurations. The docstring correctly documents this ("Sample inputs with representative shapes"), but it might be worth explicitly noting that fparam/aparam/do_atomic_virial are baked in as well.
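The baking behavior described above can be reproduced in isolation (a sketch; model and scale are illustrative, not the actual forward_lower inputs): values captured by the traced closure are recorded into the graph as constants, so later changes have no effect on the traced module.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def model(coord, scale):
    return coord * scale

# `scale` is captured by the closure at trace time and baked into the
# graph as the constant 2.0.
scale = 2.0
traced = make_fx(lambda c: model(c, scale))(torch.ones(3))

print(traced(torch.ones(3)))  # tensor([2., 2., 2.])
scale = 5.0                   # rebinding does not affect the traced graph
print(traced(torch.ones(3)))  # still tensor([2., 2., 2.])
```

Re-tracing with new sample values is required to change any baked-in input.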

source/tests/pt_expt/model/test_autodiff.py (2)

49-76: Docstring says cell shape is [nf, 3, 3], but the body reshapes to [nframes, 9].

The function accepts cell with shape [nf, 3, 3] per the docstring, which is consistent with the callers (e.g., cell.unsqueeze(0) producing [1, 3, 3]). However, line 72 reshapes it to [nframes, 9]. This works, but the docstring could note that either [nf, 3, 3] or [nf, 9] is accepted, since reshape(nframes, 9) handles both shapes silently. Very minor — just a documentation clarity nit.


154-181: Duplicated setUp between TestEnergyModelSeAForce and TestEnergyModelSeAVirial.

Both test classes have identical setUp methods. Consider extracting a shared helper or a common base mixin to reduce duplication.

That said, this pattern (identical setUp in mixin-based test classes) is common in this codebase and the duplication is small, so this is a minor nit.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1fa1eb27b2


@codecov

codecov bot commented Feb 14, 2026

Codecov Report

❌ Patch coverage is 75.09579% with 130 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.06%. Comparing base (a0bd530) to head (53a7e7f).
⚠️ Report is 11 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/dpmodel/model/dp_zbl_model.py 40.32% 37 Missing ⚠️
deepmd/dpmodel/model/dipole_model.py 40.42% 28 Missing ⚠️
deepmd/dpmodel/model/dos_model.py 48.14% 14 Missing ⚠️
deepmd/dpmodel/model/polar_model.py 48.14% 14 Missing ⚠️
deepmd/dpmodel/model/property_model.py 48.14% 14 Missing ⚠️
deepmd/pt_expt/utils/network.py 75.75% 8 Missing ⚠️
deepmd/dpmodel/model/ener_model.py 91.30% 6 Missing ⚠️
deepmd/pt_expt/model/transform_output.py 93.90% 5 Missing ⚠️
deepmd/dpmodel/model/make_model.py 96.96% 1 Missing ⚠️
deepmd/dpmodel/model/spin_model.py 85.71% 1 Missing ⚠️
... and 2 more
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5243      +/-   ##
==========================================
- Coverage   82.12%   82.06%   -0.06%     
==========================================
  Files         736      740       +4     
  Lines       74237    74667     +430     
  Branches     3615     3616       +1     
==========================================
+ Hits        60966    61277     +311     
- Misses      12107    12225     +118     
- Partials     1164     1165       +1     


@njzjz njzjz linked an issue Feb 14, 2026 that may be closed by this pull request
@njzjz
Member

njzjz commented Feb 14, 2026

How do we handle model output keys in the new backend?

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@deepmd/dpmodel/atomic_model/dp_atomic_model.py`:
- Around line 217-221: The runtime check using assert in the block guarded by
self.enable_eval_fitting_last_layer_hook is unsafe because asserts are removed
with -O; replace it with an explicit runtime exception (e.g., raise RuntimeError
or ValueError) that checks "middle_output" in ret and raises a clear error
message, then pop the key and append to self.eval_fitting_last_layer_list as
before (refer to the symbols self.enable_eval_fitting_last_layer_hook,
ret.pop("middle_output"), and self.eval_fitting_last_layer_list to locate the
code).
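A sketch of the requested change (class and method names are illustrative, modeled on the symbols mentioned above): raise an explicit exception so the check survives `python -O`, which strips `assert` statements.

```python
# Hypothetical mixin modeled on the hook symbols named in the review.
class AtomicModelHookMixin:
    def __init__(self):
        self.enable_eval_fitting_last_layer_hook = True
        self.eval_fitting_last_layer_list = []

    def _collect_middle_output(self, ret):
        if self.enable_eval_fitting_last_layer_hook:
            if "middle_output" not in ret:
                # Explicit exception instead of `assert`, so the check
                # is not removed under -O.
                raise RuntimeError(
                    "eval_fitting_last_layer hook is enabled, but the "
                    "fitting network returned no 'middle_output'."
                )
            self.eval_fitting_last_layer_list.append(ret.pop("middle_output"))
        return ret

m = AtomicModelHookMixin()
m._collect_middle_output({"energy": 1.0, "middle_output": [0.5]})
print(m.eval_fitting_last_layer_list)  # [[0.5]]
```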
🧹 Nitpick comments (1)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)

140-148: Consider clearing the cache after retrieval to prevent stale data and unbounded growth.

Both eval_descriptor() and eval_fitting_last_layer() leave the cache intact after returning the concatenated result. If a caller invokes them, then runs more forward passes, a subsequent call silently includes data from both the old and new passes. If clearing isn't desired, at minimum document the accumulation semantics so callers know to call set_eval_*_hook(True) again to reset.

Also applies to: 156-164
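A sketch of the clearing variant (DescriptorCache is a hypothetical stand-in for the accessor pair, assuming eval_descriptor-like semantics): concatenate the accumulated results, then empty the cache so a later retrieval cannot mix old and new passes.

```python
import numpy as np

class DescriptorCache:
    def __init__(self):
        self._cache = []

    def record(self, descriptor):
        # Called once per forward pass while the hook is enabled.
        self._cache.append(descriptor)

    def eval_descriptor(self):
        out = np.concatenate(self._cache)
        self._cache.clear()  # avoid stale data and unbounded growth
        return out

c = DescriptorCache()
c.record(np.zeros((2, 4)))
c.record(np.ones((3, 4)))
print(c.eval_descriptor().shape)  # (5, 4)
print(len(c._cache))              # 0
```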

@wanghan-iapcm
Collaborator Author

How do we handle model output keys in the new backend?

Shouldn't it always use the output keys in the dpmodel backend?

@wanghan-iapcm wanghan-iapcm requested a review from njzjz February 15, 2026 11:52
Comment on lines +101 to +109
def call(
self,
coord: Array,
atype: Array,
box: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
) -> dict[str, Array]:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning

This method does not accept arbitrary keyword arguments, which overridden
NativeOP.call
does.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at most 7 positional arguments, whereas overridden
NativeOP.call
may be called with arbitrarily many.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at least 3 positional arguments, whereas overridden
NativeOP.call
may be called with 1.
This call
correctly calls the base method, but does not match the signature of the overriding method.
The same CodeQL warning (signature mismatch with the overridden NativeOP.call, which may be called with arbitrary positional and keyword arguments) is repeated verbatim for five more identical call overrides, commented on lines +56 to +64, +48 to +56, +58 to +66, +48 to +56, and +51 to +59.
@wanghan-iapcm wanghan-iapcm marked this pull request as draft February 16, 2026 12:02
@wanghan-iapcm wanghan-iapcm changed the title feat(pt_expt): full model feat(pt_expt): full model and refact the module output names of dpmodel backend Feb 16, 2026
@wanghan-iapcm wanghan-iapcm added the Test CUDA Trigger test CUDA workflow label Feb 16, 2026
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Feb 16, 2026


Development

Successfully merging this pull request may close these issues.

Full Model and Autograd Support (PyTorch Exportable)

2 participants