diff --git a/API_scheme.md b/API_scheme.md
new file mode 100644
index 000000000..800a04741
--- /dev/null
+++ b/API_scheme.md
@@ -0,0 +1,49 @@
+
PINA Code Structure
+
+
+Here is a high-level overview of PINA’s main modules. For full details, refer to the
+documentation.
+
+```mermaid
+flowchart TB
+ PINA["pina
The basic module including `Condition`, LabelTensor, `Graph` and `Trainer` API"]
+
+ subgraph R1[" "]
+ direction LR
+ PROB["pina.problem
Module for defining problems via base class inheritance"]
+ MODEL["pina.model
Module for built-in PyTorch models full architectures"]
+ SOLVER["pina.solver
Module for built-in solvers and abstract interfaces"]
+ CALLBACK["pina.callback
Module for built-in callbacks to integrate training pipelines"]
+ end
+
+ subgraph R2[" "]
+ direction LR
+ DOMAIN["pina.domain
Module for defining geometries and set operations"]
+ BLOCK["pina.block
Module for built-in PyTorch models layers only"]
+ OPTIM["pina.optim
Module for build or import optimizers and schedulers"]
+ DATA["pina.data
Module for DataModules for data processing"]
+ end
+
+ subgraph R3[" "]
+ direction LR
+ OPERATOR["pina.operator
Module for differential operators"]
+ ADAPT["pina.adaptive_function
Module for PyTorch learnable activations"]
+ LOSS["pina.loss
Module for losses and weighting strategies"]
+ CONDITION["pina.condition
Module for model training constraints"]
+ end
+
+ PINA --> PROB
+ PINA --> MODEL
+ PINA --> SOLVER
+ PINA --> CALLBACK
+
+ PROB --> DOMAIN
+ MODEL --> BLOCK
+ SOLVER --> OPTIM
+ CALLBACK --> DATA
+
+ DOMAIN --> OPERATOR
+ BLOCK --> ADAPT
+ OPTIM --> LOSS
+ DATA --> CONDITION
+```
\ No newline at end of file
diff --git a/README.md b/README.md
index 81a256d70..3369cc5d5 100644
--- a/README.md
+++ b/README.md
@@ -4,38 +4,87 @@ Copyright Contributors to the Pyro project.
SPDX-License-Identifier: Apache-2.0
-->
-
-
-
-
-
-
- |
-
-
- A Unified Framework for Scientific Machine Learning
-
- |
-
-
-
-
------------------------------------------
-
-[](https://github.com/mathLab/PINA/actions/workflows/pages/pages-build-deployment)
-[](https://pypi.org/project/pina-mathlab/)
-[](https://pypi.org/project/pina-mathlab/)
-[](https://joss.theoj.org/papers/10.21105/joss.05352)
-[](https://github.com/mathLab/PINA/blob/main/LICENSE.rst)
-
-
-[Getting Started](https://github.com/mathLab/PINA/tree/master/tutorials#pina-tutorials) |
-[Documentation](https://mathlab.github.io/PINA/) |
-[Contributing](https://github.com/mathLab/PINA/blob/master/CONTRIBUTING.md)
-
-**PINA** is an open-source Python library designed to simplify and accelerate the development of Scientific Machine Learning (SciML) solutions. Built on top of [PyTorch](https://pytorch.org/), [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), and [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/), PINA provides an intuitive framework for defining, experimenting with, and solving complex problems using Neural Networks, Physics-Informed Neural Networks (PINNs), Neural Operators, and more.
+
+
+
+
+
+
+
+
+ 
+
+
+
+
+
+
+
+
+
+
+
+ A Unified Framework for Scientific Machine Learning
+
+
+
+
+
+ PINA is an open-source Python library designed to simplify and accelerate the development of
+ Scientific Machine Learning (SciML) solutions, including PINNs, Neural Operators,
+ data-driven modeling, and more.
+
+
+
+
Built on top of
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+News & Announcements
+
+
+
+ -
+ [v0.3] – New solvers: autoregressive solver for sequential prediction tasks and multi-model solver support. Internals redesigned around a mixin architecture — lightweight, single-responsibility mixins (preprocessing, forward, postprocessing) that can be freely composed, with residual computation and loss aggregation clearly separated.
+
+ -
+ [v0.3] – Conditions refactoring: evaluation logic moved out of the solver and into the condition itself via a dedicated
evaluate method, decoupling the training loop from problem-specific logic and enabling fully modular, solver-agnostic conditions.
+
+ -
+ [v0.3] – Code cleanup: core internals migrated to the
_src pattern; interfaces and base classes introduced across conditions, problems (AbstractProblem → BaseProblem), losses, and data module; equation zoo reorganized with Burgers added.
+
+ -
+ [v0.3] – KAN support: Kolmogorov–Arnold Networks with fully vectorized spline basis and analytical derivatives.
+
+
+
+
+
+ Want the full history?
+ See the Releases page.
+
+
+
+
+What's PINA
+
+PINA provides an intuitive framework for defining, experimenting with, and solving complex problems using Neural Networks, Physics-Informed Neural Networks (PINNs), Neural Operators, and more.
- **Modular Architecture**: Designed with modularity in mind and relying on powerful yet composable abstractions, PINA allows users to easily plug, replace, or extend components, making experimentation and customization straightforward.
@@ -45,51 +94,128 @@ SPDX-License-Identifier: Apache-2.0
-## Installation
+
-### Installing a stable PINA release
-**Install using pip:**
-```sh
-pip install "pina-mathlab"
-```
-**Install from source:**
-```sh
-git clone https://github.com/mathLab/PINA
-cd PINA
-git checkout master
-pip install .
-```
+
-**Install with extra packages:**
+
+
+ Installation
+
+
+
+
+
-To install extra dependencies required to run tests or tutorials directories, please use the following command:
-```sh
-pip install "pina-mathlab[extras]"
-```
-Available extras include:
-* `dev` for development purpuses, use this if you want to [Contribute](https://github.com/mathLab/PINA/blob/master/CONTRIBUTING.md#contributing-to-pina).
-* `test` for running test locally.
-* `doc` for building documentation locally.
-* `tutorial` for running [Tutorials](https://github.com/mathLab/PINA/tree/master/tutorials#pina-tutorials).
+Install a stable release
-## Quick Tour for New Users
-Solving a differential problem in **PINA** follows the *four steps pipeline*:
+pip install "pina-mathlab"
-1. Define the problem to be solved with its constraints using the [Problem API](https://mathlab.github.io/PINA/_rst/_code.html#problems).
+Install from source
-2. Design your model using PyTorch, or for graph-based problems, leverage PyTorch Geometric to build Graph Neural Networks. You can also import models directly from the [Model API](https://mathlab.github.io/PINA/_rst/_code.html#models).
+git clone https://github.com/mathLab/PINA
+cd PINA
+git checkout master
+pip install .
+
+
+Install with extra dependencies
+
+
+To install additional packages required for development, tests, docs, or tutorials:
+
+
+pip install "pina-mathlab[extras]"
+
+Available extras:
+
+
+ dev for development purposes
+ test for running tests locally
+ doc for building documentation locally
+ tutorial for running tutorials
+
+
+
+
+
+
+
+
+ Getting started with PINA
+
+
+
+
+
+
+
+
+
+
+Solving a differential problem in PINA follows a clean four-step pipeline:
+
+
+
+ -
+ Define the problem and constraints using the
+ Problem API.
+
+ -
+ Design your model using PyTorch, PyTorch Geometric, or import from the
+ Model API.
+
+ -
+ Select or build a Solver using the
+ Solver API.
+
+ -
+ Train with the
+ Trainer API,
+ powered by PyTorch Lightning.
+
+
+
+```mermaid
+flowchart LR
+ STEP1["Problem and Data
Define the mathematical problem
Identify constraints or import data"]
+ STEP2["Model Design
Build a PyTorch module Choose or customize a model"]
+ STEP3["Solver Selection
Use available solvers or define your own strategy"]
+ STEP4["Training
Optimize the model with PyTorch Lightning"]
+
+ STEP1 e1@--> STEP2
+ STEP2 e2@--> STEP3
+ STEP3 e3@--> STEP4
+ e1@{ animate: true }
+ e2@{ animate: true }
+ e3@{ animate: true }
+```
-3. Select or build a Solver for the Problem, e.g., supervised solvers, or physics-informed (e.g., PINN) solvers. [PINA Solvers](https://mathlab.github.io/PINA/_rst/_code.html#solvers) are modular and can be used as-is or customized.
+
+Want to dive deeper? Check out the official
+Tutorials.
+
-4. Train the model using the [Trainer API](https://mathlab.github.io/PINA/_rst/trainer.html) class, built on PyTorch Lightning, which supports efficient, scalable training with advanced features.
+
+
-Do you want to learn more about it? Look at our [Tutorials](https://github.com/mathLab/PINA/tree/master/tutorials#pina-tutorials).
+
+
+ PINA by Examples
+
+
+
+
+
-### Solve Data Driven Problems
-Data driven modelling aims to learn a function that given some input data gives an output (e.g. regression, classification, ...). In PINA you can easily do this by:
-```python
+
+Data-Driven Modeling Example
+
+```python
import torch
from pina import Trainer
from pina.model import FeedForward
@@ -101,16 +227,28 @@ target_tensor = input_tensor.pow(3)
# Step 1. Define problem
problem = SupervisedProblem(input_tensor, target_tensor)
-# Step 2. Design model (you can use your favourite torch.nn.Module in here)
-model = FeedForward(input_dimensions=1, output_dimensions=1, layers=[64, 64])
-# Step 3. Define Solver
-solver = SupervisedSolver(problem, model, use_lt=False)
+
+# Step 2. Define model
+model = FeedForward(input_dimensions=1, output_dimensions=1, layers=[64, 64])
+
+# Step 3. Define solver
+solver = SupervisedSolver(problem, model, use_lt=False)
+
# Step 4. Train
-trainer = Trainer(solver, max_epochs=1000, accelerator='gpu')
+trainer = Trainer(solver, max_epochs=1000, accelerator="gpu")
trainer.train()
```
-### Solve Physics Informed Problems
-Physics-informed modeling aims to learn functions that not only fit data, but also satisfy known physical laws, such as differential equations or boundary conditions. For example, the following differential problem:
+
+
+
+
+
+
+Physics-Informed Example
+
+
+Consider the following differential problem:
+
$$
\begin{cases}
@@ -118,8 +256,9 @@ $$
u(x=0) &= 1
\end{cases}
$$
-
-in PINA, can be easily implemented by:
+
+In PINA, this can be implemented as:
+
```python
from pina import Trainer, Condition
@@ -135,7 +274,6 @@ def ode_equation(input_, output_):
u = output_.extract(["u"])
return u_x - u
-# build the problem
class SimpleODE(SpatialProblem):
output_variables = ["u"]
spatial_domain = CartesianDomain({"x": [0, 1]})
@@ -151,52 +289,81 @@ class SimpleODE(SpatialProblem):
# Step 1. Define problem
problem = SimpleODE()
problem.discretise_domain(n=100, mode="grid", domains=["D", "x0"])
-# Step 2. Design model (you can use your favourite torch.nn.Module in here)
-model = FeedForward(input_dimensions=1, output_dimensions=1, layers=[64, 64])
-# Step 3. Define Solver
-solver = PINN(problem, model)
-# Step 4. Train
-trainer = Trainer(solver, max_epochs=1000, accelerator='gpu')
-trainer.train()
-```
-
-## Application Programming Interface
-Here's a quick look at PINA's main module. For a better experience and full details, check out the [documentation](https://mathlab.github.io/PINA/).
-
-
-
-
-
-## Contributing and Community
-
-We would love to develop PINA together with our community! Best way to get started is to select any issue from the [`good-first-issue` label](https://github.com/mathLab/PINA/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). If you would like to contribute, please review our [Contributing Guide](CONTRIBUTING.md) for all relevant details.
-
-We warmly thank all the contributors that have supported PINA so far:
-
-
-
-
-Made with [contrib.rocks](https://contrib.rocks).
-
-## Citation
-If **PINA** has been significant in your research, and you would like to acknowledge the project in your academic publication, we suggest citing the following paper:
+# Step 2. Define model
+model = FeedForward(input_dimensions=1, output_dimensions=1, layers=[64, 64])
-```
-Coscia, D., Ivagnes, A., Demo, N., & Rozza, G. (2023). Physics-Informed Neural networks for Advanced modeling. Journal of Open Source Software, 8(87), 5352.
-```
+# Step 3. Define solver
+solver = PINN(problem, model)
-Or in BibTex format
-```
-@article{coscia2023physics,
- title={Physics-Informed Neural networks for Advanced modeling},
- author={Coscia, Dario and Ivagnes, Anna and Demo, Nicola and Rozza, Gianluigi},
- journal={Journal of Open Source Software},
- volume={8},
- number={87},
- pages={5352},
- year={2023}
- }
+# Step 4. Train
+trainer = Trainer(solver, max_epochs=1000, accelerator="gpu")
+trainer.train()
```
+
+
+
+
+
+
+ Contributing & Community
+
+
+
+
+We would love to develop PINA together with the community.
+A great place to start is the list of
+
+ good-first-issue
+
+issues.
+
+
+
+If you would like to contribute, please read the
+Contributing Guide.
+
+
+
+
+
+
+
+
+
+ Made with contrib.rocks.
+
+
+
+
+
+
+
+ Citation
+
+
+
+
+
+
+If PINA has been significant in your research and you would like to acknowledge it, please cite:
+
+
+Coscia, D., Ivagnes, A., Demo, N., & Rozza, G. (2023).
+Physics-Informed Neural networks for Advanced modeling.
+Journal of Open Source Software, 8(87), 5352.
+
+Or in BibTeX format:
+
+@article{coscia2023physics,
+ title={Physics-Informed Neural networks for Advanced modeling},
+ author={Coscia, Dario and Ivagnes, Anna and Demo, Nicola and Rozza, Gianluigi},
+ journal={Journal of Open Source Software},
+ volume={8},
+ number={87},
+ pages={5352},
+ year={2023}
+}
+
diff --git a/SECURITY.md b/SECURITY.md
index b1dfe91f8..a425740c2 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -10,6 +10,7 @@ Security fixes are given priority and might be enough to cause a new version to
| Version | Supported |
| ------- | ------------------ |
+| 0.3 | ✅ |
| 0.2 | ✅ |
| 0.1 | ✅ |
diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index 64d88bc8b..6b2111946 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -15,17 +15,20 @@ The pipeline to solve differential equations with PINA follows just five steps:
2. Generate data using built in `Geometrical Domains`_, or load high level simulation results as :doc:`LabelTensor `
3. Choose or build one or more `Models`_ to solve the problem
4. Choose a solver across PINA available `Solvers`_, or build one using the :doc:`SolverInterface `
- 5. Train the model with the PINA :doc:`Trainer `, enhance the train with `Callbacks`_
+ 5. Train the model with the PINA :doc:`Trainer `, enhance the train with `Callbacks`_
-Trainer, Dataset and Datamodule
---------------------------------
+Trainer, Data Loader and Data Module
+----------------------------------------
.. toctree::
:titlesonly:
Trainer
- Dataset
- DataModule
+ Data Module
+ Single-Batch Data Loader
+ Aggregator
+ Creator
+ Condition Subset
Data Types
------------
@@ -52,37 +55,65 @@ Conditions
.. toctree::
:titlesonly:
- ConditionInterface
+ Condition Interface
+ Base Condition
Condition
- DataCondition
- DomainEquationCondition
- InputEquationCondition
- InputTargetCondition
+ Data Condition
+ Domain Equation Condition
+ Graph Time Series Condition
+ Input Equation Condition
+ Input Target Condition
+ Time Series Condition
+
+Batch and Data Managers
+--------------------------
+.. toctree::
+ :titlesonly:
+
+ Batch Manager
+ Data Manager Interface
+ Data Manager
+ Graph Data Manager
+ Tensor Data Manager
Solvers
---------------
+------------------------
.. toctree::
:titlesonly:
- SolverInterface
- SingleSolverInterface
- MultiSolverInterface
- SupervisedSolverInterface
- DeepEnsembleSolverInterface
- PINNInterface
- PINN
- GradientPINN
- CausalPINN
- CompetitivePINN
- SelfAdaptivePINN
- RBAPINN
- DeepEnsemblePINN
- SupervisedSolver
- DeepEnsembleSupervisedSolver
- ReducedOrderModelSolver
- GAROM
+ Solver Interface
+ Base Solver
+ Single-Model Solver
+ Multi-Model Solver
+ Ensemble Solver
+ Supervised Single-Model Solver
+ Supervised Ensemble Solver
+ Physics-Informed Single-Model Solver
+ Physics-Informed Ensemble Solver
+ Autoregressive Single-Model Solver
+ Autoregressive Ensemble Solver
+ Self-Adaptive Physics-Informed Solver
+ Competitive Physics-Informed Solver
+ Gradient Physics-Informed Single-Model Solver
+ RBA Physics-Informed Single-Model Solver
+ Causal Physics-Informed Single-Model Solver
+
+Mixins
+------------------------
+
+.. toctree::
+ :titlesonly:
+ Single-Model Mixin
+ Multi-Model Mixin
+ Ensemble Mixin
+ Condition Aggregator Mixin
+ Manual Optimization Mixin
+ Physics-Informed Mixin
+ Autoregressive Mixin
+ Gradient-Enhanced Mixin
+ Residual-Based Attention Mixin
Models
------------
@@ -108,6 +139,8 @@ Models
PirateNet
EquivariantGraphNeuralOperator
SINDy
+ Vectorized Spline
+ Kolmogorov-Arnold Network
Blocks
-------------
@@ -126,6 +159,7 @@ Blocks
Continuous Convolution Block
Orthogonal Block
PirateNet Block
+ KAN Block
Message Passing
-------------------
@@ -157,31 +191,32 @@ Optimizers and Schedulers
.. toctree::
:titlesonly:
- Optimizer
- Scheduler
- TorchOptimizer
- TorchScheduler
+ Optimizer Interface
+ Scheduler Interface
+ Torch Optimizer
+ Torch Scheduler
-Adaptive Activation Functions
+Adaptive Functions
-------------------------------
.. toctree::
:titlesonly:
- Adaptive Function Interface
- Adaptive ReLU
- Adaptive Sigmoid
- Adaptive Tanh
- Adaptive SiLU
- Adaptive Mish
- Adaptive ELU
- Adaptive CELU
- Adaptive GELU
- Adaptive Softmin
- Adaptive Softmax
- Adaptive SIREN
- Adaptive Exp
+ Adaptive Function Interface
+ Base Adaptive Function
+ Adaptive CELU
+ Adaptive ELU
+ Adaptive Exp
+ Adaptive GELU
+ Adaptive Mish
+ Adaptive ReLU
+ Adaptive Sigmoid
+ Adaptive SiLU
+ Adaptive SIREN
+ Adaptive Softmax
+ Adaptive Softmin
+ Adaptive Tanh
Equations and Differential Operators
@@ -190,39 +225,60 @@ Equations and Differential Operators
.. toctree::
:titlesonly:
- EquationInterface
+ Equation Interface
+ Base Equation
Equation
- SystemEquation
- Equation Factory
+ System Equation
Differential Operators
+Equation Zoo
+---------------------------------------
+
+.. toctree::
+ :titlesonly:
+
+ Acoustic Wave Equation
+ Advection Equation
+ Allen-Cahn Equation
+ Burgers' Equation
+ Diffusion-Reaction Equation
+ Fixed Flux
+ Fixed Gradient
+ Fixed Laplacian
+ Fixed Value
+ Helmholtz Equation
+ Poisson Equation
+
+
Problems
--------------
.. toctree::
:titlesonly:
- AbstractProblem
+ ProblemInterface
+ BaseProblem
InverseProblem
ParametricProblem
SpatialProblem
TimeDependentProblem
-Problems Zoo
+Problem Zoo
--------------
.. toctree::
:titlesonly:
- AcousticWaveProblem
- AdvectionProblem
- AllenCahnProblem
- DiffusionReactionProblem
- HelmholtzProblem
- InversePoisson2DSquareProblem
- Poisson2DSquareProblem
- SupervisedProblem
+ Acoustic Wave Problem
+ Advection Problem
+ Allen-Cahn Problem
+ Burgers' Problem
+ Diffusion-Reaction Problem
+ Helmholtz Problem
+ Inverse Poisson 2D Square Problem
+ Poisson 2D Square Problem
+ Supervised Problem
Geometrical Domains
@@ -258,23 +314,37 @@ Callbacks
Switch Optimizer
Switch Scheduler
- Normalizer Data
- PINA Progress Bar
- Metric Tracker
Refinement Interface
+ Base Refinement
R3 Refinement
+ Data Normalizer
+ Metric Tracker
+ PINA Progress Bar
+
+
+Losses
+---------
-Losses and Weightings
----------------------
+.. toctree::
+ :titlesonly:
+
+ DualLossInterface
+ BaseDualLoss
+ LpLoss
+ PowerLoss
+ SinkhornLoss
+
+
+Weighting Schemas
+--------------------
.. toctree::
:titlesonly:
- LossInterface
- LpLoss
- PowerLoss
- WeightingInterface
- ScalarWeighting
- NeuralTangentKernelWeighting
- SelfAdaptiveWeighting
- LinearWeighting
\ No newline at end of file
+ Weighting Interface
+ Base Weighting
+ Linear Weighting
+ Neural-Tangent-Kernel Weighting
+ No Weighting
+ Scalar Weighting
+ Self-Adaptive Weighting
diff --git a/docs/source/_rst/adaptive_function/AdaptiveActivationFunctionInterface.rst b/docs/source/_rst/adaptive_function/AdaptiveActivationFunctionInterface.rst
deleted file mode 100644
index cf8b6551d..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveActivationFunctionInterface.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-AdaptiveActivationFunctionInterface
-=======================================
-
-.. currentmodule:: pina.adaptive_function.adaptive_function_interface
-
-.. automodule:: pina.adaptive_function.adaptive_function_interface
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/adaptive_function/AdaptiveCELU.rst b/docs/source/_rst/adaptive_function/AdaptiveCELU.rst
deleted file mode 100644
index c4d6d5429..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveCELU.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveCELU
-============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveCELU
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveELU.rst b/docs/source/_rst/adaptive_function/AdaptiveELU.rst
deleted file mode 100644
index aab273b08..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveELU.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveELU
-===========
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveELU
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveExp.rst b/docs/source/_rst/adaptive_function/AdaptiveExp.rst
deleted file mode 100644
index a7ee52b20..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveExp.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveExp
-===========
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveExp
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveGELU.rst b/docs/source/_rst/adaptive_function/AdaptiveGELU.rst
deleted file mode 100644
index b4aef14dc..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveGELU.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveGELU
-============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveGELU
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveMish.rst b/docs/source/_rst/adaptive_function/AdaptiveMish.rst
deleted file mode 100644
index d006df054..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveMish.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveMish
-============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveMish
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveReLU.rst b/docs/source/_rst/adaptive_function/AdaptiveReLU.rst
deleted file mode 100644
index d0fe4de68..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveReLU.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveReLU
-============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveReLU
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveSIREN.rst b/docs/source/_rst/adaptive_function/AdaptiveSIREN.rst
deleted file mode 100644
index 9f132547b..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveSIREN.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveSIREN
-=============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveSIREN
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveSiLU.rst b/docs/source/_rst/adaptive_function/AdaptiveSiLU.rst
deleted file mode 100644
index 722678611..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveSiLU.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveSiLU
-============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveSiLU
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveSigmoid.rst b/docs/source/_rst/adaptive_function/AdaptiveSigmoid.rst
deleted file mode 100644
index 6002ffb31..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveSigmoid.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveSigmoid
-===============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveSigmoid
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveSoftmax.rst b/docs/source/_rst/adaptive_function/AdaptiveSoftmax.rst
deleted file mode 100644
index c2b4c9f09..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveSoftmax.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveSoftmax
-===============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveSoftmax
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveSoftmin.rst b/docs/source/_rst/adaptive_function/AdaptiveSoftmin.rst
deleted file mode 100644
index 5189cb391..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveSoftmin.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveSoftmin
-===============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveSoftmin
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/AdaptiveTanh.rst b/docs/source/_rst/adaptive_function/AdaptiveTanh.rst
deleted file mode 100644
index 9a9b380a3..000000000
--- a/docs/source/_rst/adaptive_function/AdaptiveTanh.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdaptiveTanh
-============
-
-.. currentmodule:: pina.adaptive_function.adaptive_function
-
-.. autoclass:: AdaptiveTanh
- :members:
- :show-inheritance:
- :inherited-members: AdaptiveActivationFunctionInterface
diff --git a/docs/source/_rst/adaptive_function/adaptive_celu.rst b/docs/source/_rst/adaptive_function/adaptive_celu.rst
new file mode 100644
index 000000000..b04bcf42b
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_celu.rst
@@ -0,0 +1,9 @@
+Adaptive CELU
+==================
+.. currentmodule:: pina.adaptive_function.adaptive_celu
+
+.. automodule:: pina._src.adaptive_function.adaptive_celu
+
+.. autoclass:: pina._src.adaptive_function.adaptive_celu.AdaptiveCELU
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_elu.rst b/docs/source/_rst/adaptive_function/adaptive_elu.rst
new file mode 100644
index 000000000..e758b20b3
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_elu.rst
@@ -0,0 +1,9 @@
+Adaptive ELU
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_elu
+
+.. automodule:: pina._src.adaptive_function.adaptive_elu
+
+.. autoclass:: pina._src.adaptive_function.adaptive_elu.AdaptiveELU
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_exp.rst b/docs/source/_rst/adaptive_function/adaptive_exp.rst
new file mode 100644
index 000000000..3feeb6192
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_exp.rst
@@ -0,0 +1,9 @@
+Adaptive Exp
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_exp
+
+.. automodule:: pina._src.adaptive_function.adaptive_exp
+
+.. autoclass:: pina._src.adaptive_function.adaptive_exp.AdaptiveExp
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_function_interface.rst b/docs/source/_rst/adaptive_function/adaptive_function_interface.rst
new file mode 100644
index 000000000..e7859c0d2
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_function_interface.rst
@@ -0,0 +1,9 @@
+Adaptive Function Interface
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_function_interface
+
+.. automodule:: pina._src.adaptive_function.adaptive_function_interface
+
+.. autoclass:: pina._src.adaptive_function.adaptive_function_interface.AdaptiveFunctionInterface
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_gelu.rst b/docs/source/_rst/adaptive_function/adaptive_gelu.rst
new file mode 100644
index 000000000..a07960373
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_gelu.rst
@@ -0,0 +1,9 @@
+Adaptive GELU
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_gelu
+
+.. automodule:: pina._src.adaptive_function.adaptive_gelu
+
+.. autoclass:: pina._src.adaptive_function.adaptive_gelu.AdaptiveGELU
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_mish.rst b/docs/source/_rst/adaptive_function/adaptive_mish.rst
new file mode 100644
index 000000000..f56c911fb
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_mish.rst
@@ -0,0 +1,9 @@
+Adaptive Mish
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_mish
+
+.. automodule:: pina._src.adaptive_function.adaptive_mish
+
+.. autoclass:: pina._src.adaptive_function.adaptive_mish.AdaptiveMish
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_relu.rst b/docs/source/_rst/adaptive_function/adaptive_relu.rst
new file mode 100644
index 000000000..a2032f344
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_relu.rst
@@ -0,0 +1,9 @@
+Adaptive ReLU
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_relu
+
+.. automodule:: pina._src.adaptive_function.adaptive_relu
+
+.. autoclass:: pina._src.adaptive_function.adaptive_relu.AdaptiveReLU
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_sigmoid.rst b/docs/source/_rst/adaptive_function/adaptive_sigmoid.rst
new file mode 100644
index 000000000..8aef91c0d
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_sigmoid.rst
@@ -0,0 +1,9 @@
+Adaptive Sigmoid
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_sigmoid
+
+.. automodule:: pina._src.adaptive_function.adaptive_sigmoid
+
+.. autoclass:: pina._src.adaptive_function.adaptive_sigmoid.AdaptiveSigmoid
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_silu.rst b/docs/source/_rst/adaptive_function/adaptive_silu.rst
new file mode 100644
index 000000000..2d22dcf20
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_silu.rst
@@ -0,0 +1,9 @@
+Adaptive SiLU
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_silu
+
+.. automodule:: pina._src.adaptive_function.adaptive_silu
+
+.. autoclass:: pina._src.adaptive_function.adaptive_silu.AdaptiveSiLU
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_siren.rst b/docs/source/_rst/adaptive_function/adaptive_siren.rst
new file mode 100644
index 000000000..167cd79ff
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_siren.rst
@@ -0,0 +1,9 @@
+Adaptive SIREN
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_siren
+
+.. automodule:: pina._src.adaptive_function.adaptive_siren
+
+.. autoclass:: pina._src.adaptive_function.adaptive_siren.AdaptiveSIREN
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_softmax.rst b/docs/source/_rst/adaptive_function/adaptive_softmax.rst
new file mode 100644
index 000000000..8797acae9
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_softmax.rst
@@ -0,0 +1,9 @@
+Adaptive Softmax
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_softmax
+
+.. automodule:: pina._src.adaptive_function.adaptive_softmax
+
+.. autoclass:: pina._src.adaptive_function.adaptive_softmax.AdaptiveSoftmax
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_softmin.rst b/docs/source/_rst/adaptive_function/adaptive_softmin.rst
new file mode 100644
index 000000000..72ed8ae1f
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_softmin.rst
@@ -0,0 +1,9 @@
+Adaptive Softmin
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_softmin
+
+.. automodule:: pina._src.adaptive_function.adaptive_softmin
+
+.. autoclass:: pina._src.adaptive_function.adaptive_softmin.AdaptiveSoftmin
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/adaptive_tanh.rst b/docs/source/_rst/adaptive_function/adaptive_tanh.rst
new file mode 100644
index 000000000..dbd9e4313
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/adaptive_tanh.rst
@@ -0,0 +1,9 @@
+Adaptive Tanh
+=============================
+.. currentmodule:: pina.adaptive_function.adaptive_tanh
+
+.. automodule:: pina._src.adaptive_function.adaptive_tanh
+
+.. autoclass:: pina._src.adaptive_function.adaptive_tanh.AdaptiveTanh
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/adaptive_function/base_adaptive_function.rst b/docs/source/_rst/adaptive_function/base_adaptive_function.rst
new file mode 100644
index 000000000..6b1e6cee7
--- /dev/null
+++ b/docs/source/_rst/adaptive_function/base_adaptive_function.rst
@@ -0,0 +1,9 @@
+Base Adaptive Function
+=============================
+.. currentmodule:: pina.adaptive_function.base_adaptive_function
+
+.. automodule:: pina._src.adaptive_function.base_adaptive_function
+
+.. autoclass:: pina._src.adaptive_function.base_adaptive_function.BaseAdaptiveFunction
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/callback/optim/switch_optimizer.rst b/docs/source/_rst/callback/optim/switch_optimizer.rst
index 635e79a18..13b7db7ad 100644
--- a/docs/source/_rst/callback/optim/switch_optimizer.rst
+++ b/docs/source/_rst/callback/optim/switch_optimizer.rst
@@ -2,6 +2,8 @@ Switch Optimizer
=====================
.. currentmodule:: pina.callback.optim.switch_optimizer
+.. automodule:: pina._src.callback.optim.switch_optimizer
+ :show-inheritance:
.. autoclass:: SwitchOptimizer
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/callback/optim/switch_scheduler.rst b/docs/source/_rst/callback/optim/switch_scheduler.rst
index 3176904da..42d5e6be0 100644
--- a/docs/source/_rst/callback/optim/switch_scheduler.rst
+++ b/docs/source/_rst/callback/optim/switch_scheduler.rst
@@ -2,6 +2,8 @@ Switch Scheduler
=====================
.. currentmodule:: pina.callback.optim.switch_scheduler
+.. automodule:: pina._src.callback.optim.switch_scheduler
+ :show-inheritance:
.. autoclass:: SwitchScheduler
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/callback/processing/data_normalizer.rst b/docs/source/_rst/callback/processing/data_normalizer.rst
new file mode 100644
index 000000000..358d2f472
--- /dev/null
+++ b/docs/source/_rst/callback/processing/data_normalizer.rst
@@ -0,0 +1,9 @@
+Data Normalizer
+=======================
+.. currentmodule:: pina.callback.processing.data_normalizer
+
+.. automodule:: pina._src.callback.processing.data_normalizer
+
+.. autoclass:: pina._src.callback.processing.data_normalizer.DataNormalizer
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/callback/processing/metric_tracker.rst b/docs/source/_rst/callback/processing/metric_tracker.rst
index f21cc7730..22d7cc229 100644
--- a/docs/source/_rst/callback/processing/metric_tracker.rst
+++ b/docs/source/_rst/callback/processing/metric_tracker.rst
@@ -2,6 +2,9 @@ Metric Tracker
==================
.. currentmodule:: pina.callback.processing.metric_tracker
-.. autoclass:: MetricTracker
+.. automodule:: pina._src.callback.processing.metric_tracker
+
+.. autoclass:: pina._src.callback.processing.metric_tracker.MetricTracker
:members:
- :show-inheritance:
\ No newline at end of file
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/callback/processing/normalizer_data_callback.rst b/docs/source/_rst/callback/processing/normalizer_data_callback.rst
deleted file mode 100644
index a44f0c402..000000000
--- a/docs/source/_rst/callback/processing/normalizer_data_callback.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-Normalizer Data
-=======================
-
-.. currentmodule:: pina.callback.processing.normalizer_data_callback
-.. autoclass:: NormalizerDataCallback
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/callback/processing/pina_progress_bar.rst b/docs/source/_rst/callback/processing/pina_progress_bar.rst
index 1d42ad120..9c64678eb 100644
--- a/docs/source/_rst/callback/processing/pina_progress_bar.rst
+++ b/docs/source/_rst/callback/processing/pina_progress_bar.rst
@@ -2,6 +2,8 @@ PINA Progress Bar
==================
.. currentmodule:: pina.callback.processing.pina_progress_bar
-.. autoclass:: PINAProgressBar
+.. automodule:: pina._src.callback.processing.pina_progress_bar
+
+.. autoclass:: pina._src.callback.processing.pina_progress_bar.PINAProgressBar
:members:
- :show-inheritance:
\ No newline at end of file
+ :show-inheritance:
diff --git a/docs/source/_rst/callback/refinement/base_refinement.rst b/docs/source/_rst/callback/refinement/base_refinement.rst
new file mode 100644
index 000000000..5f8eaf218
--- /dev/null
+++ b/docs/source/_rst/callback/refinement/base_refinement.rst
@@ -0,0 +1,7 @@
+Base Refinement
+=======================
+
+.. currentmodule:: pina.callback.refinement.base_refinement
+.. autoclass:: pina._src.callback.refinement.base_refinement.BaseRefinement
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/callback/refinement/r3_refinement.rst b/docs/source/_rst/callback/refinement/r3_refinement.rst
index eb3bfebf2..0d787c840 100644
--- a/docs/source/_rst/callback/refinement/r3_refinement.rst
+++ b/docs/source/_rst/callback/refinement/r3_refinement.rst
@@ -1,7 +1,7 @@
-Refinments callbacks
+R3 Refinement
=======================
-.. currentmodule:: pina.callback.refinement
-.. autoclass:: R3Refinement
+.. currentmodule:: pina.callback.refinement.r3_refinement
+.. autoclass:: pina._src.callback.refinement.r3_refinement.R3Refinement
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/callback/refinement/refinement_interface.rst b/docs/source/_rst/callback/refinement/refinement_interface.rst
index 5e02f2dc3..1af845800 100644
--- a/docs/source/_rst/callback/refinement/refinement_interface.rst
+++ b/docs/source/_rst/callback/refinement/refinement_interface.rst
@@ -1,7 +1,7 @@
Refinement Interface
=======================
-.. currentmodule:: pina.callback.refinement
-.. autoclass:: RefinementInterface
+.. currentmodule:: pina.callback.refinement.refinement_interface
+.. autoclass:: pina._src.callback.refinement.refinement_interface.RefinementInterface
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/condition/base_condition.rst b/docs/source/_rst/condition/base_condition.rst
new file mode 100644
index 000000000..2ba4113bd
--- /dev/null
+++ b/docs/source/_rst/condition/base_condition.rst
@@ -0,0 +1,9 @@
+Base Condition
+================
+.. currentmodule:: pina.condition.base_condition
+
+.. automodule:: pina._src.condition.base_condition
+
+.. autoclass:: pina._src.condition.base_condition.BaseCondition
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/condition/condition.rst b/docs/source/_rst/condition/condition.rst
index 51edfafff..0f8070506 100644
--- a/docs/source/_rst/condition/condition.rst
+++ b/docs/source/_rst/condition/condition.rst
@@ -1,7 +1,9 @@
-Conditions
+Condition
=============
.. currentmodule:: pina.condition.condition
-.. autoclass:: Condition
+.. automodule:: pina._src.condition.condition
+
+.. autoclass:: pina._src.condition.condition.Condition
:members:
- :show-inheritance:
\ No newline at end of file
+ :show-inheritance:
diff --git a/docs/source/_rst/condition/condition_interface.rst b/docs/source/_rst/condition/condition_interface.rst
index 88459629b..a81de1afa 100644
--- a/docs/source/_rst/condition/condition_interface.rst
+++ b/docs/source/_rst/condition/condition_interface.rst
@@ -1,7 +1,9 @@
-ConditionInterface
+Condition Interface
======================
.. currentmodule:: pina.condition.condition_interface
-.. autoclass:: ConditionInterface
+.. automodule:: pina._src.condition.condition_interface
+
+.. autoclass:: pina._src.condition.condition_interface.ConditionInterface
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/condition/data_condition.rst b/docs/source/_rst/condition/data_condition.rst
index b7c322ea1..d614fbb7b 100644
--- a/docs/source/_rst/condition/data_condition.rst
+++ b/docs/source/_rst/condition/data_condition.rst
@@ -1,15 +1,9 @@
-Data Conditions
+Data Condition
==================
.. currentmodule:: pina.condition.data_condition
-.. autoclass:: DataCondition
- :members:
- :show-inheritance:
-
-.. autoclass:: GraphDataCondition
- :members:
- :show-inheritance:
+.. automodule:: pina._src.condition.data_condition
-.. autoclass:: TensorDataCondition
+.. autoclass:: pina._src.condition.data_condition.DataCondition
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/condition/domain_equation_condition.rst b/docs/source/_rst/condition/domain_equation_condition.rst
index 505c8b839..2c372f13f 100644
--- a/docs/source/_rst/condition/domain_equation_condition.rst
+++ b/docs/source/_rst/condition/domain_equation_condition.rst
@@ -2,6 +2,8 @@ Domain Equation Condition
===========================
.. currentmodule:: pina.condition.domain_equation_condition
-.. autoclass:: DomainEquationCondition
+.. automodule:: pina._src.condition.domain_equation_condition
+
+.. autoclass:: pina._src.condition.domain_equation_condition.DomainEquationCondition
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/condition/graph_time_series_condition.rst b/docs/source/_rst/condition/graph_time_series_condition.rst
new file mode 100644
index 000000000..6314980fb
--- /dev/null
+++ b/docs/source/_rst/condition/graph_time_series_condition.rst
@@ -0,0 +1,10 @@
+Graph Time Series Condition
+=============================
+
+.. currentmodule:: pina.condition.graph_time_series_condition
+
+.. automodule:: pina._src.condition.graph_time_series_condition
+
+.. autoclass:: pina._src.condition.graph_time_series_condition.GraphTimeSeriesCondition
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/condition/input_equation_condition.rst b/docs/source/_rst/condition/input_equation_condition.rst
index 4f5450e93..da0a48476 100644
--- a/docs/source/_rst/condition/input_equation_condition.rst
+++ b/docs/source/_rst/condition/input_equation_condition.rst
@@ -2,14 +2,8 @@ Input Equation Condition
===========================
.. currentmodule:: pina.condition.input_equation_condition
-.. autoclass:: InputEquationCondition
- :members:
- :show-inheritance:
-
-.. autoclass:: InputTensorEquationCondition
- :members:
- :show-inheritance:
+.. automodule:: pina._src.condition.input_equation_condition
-.. autoclass:: InputGraphEquationCondition
+.. autoclass:: pina._src.condition.input_equation_condition.InputEquationCondition
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/condition/input_target_condition.rst b/docs/source/_rst/condition/input_target_condition.rst
index 960b7d6f4..da8333714 100644
--- a/docs/source/_rst/condition/input_target_condition.rst
+++ b/docs/source/_rst/condition/input_target_condition.rst
@@ -2,22 +2,8 @@ Input Target Condition
===========================
.. currentmodule:: pina.condition.input_target_condition
-.. autoclass:: InputTargetCondition
- :members:
- :show-inheritance:
-
-.. autoclass:: TensorInputTensorTargetCondition
- :members:
- :show-inheritance:
+.. automodule:: pina._src.condition.input_target_condition
-.. autoclass:: TensorInputGraphTargetCondition
+.. autoclass:: pina._src.condition.input_target_condition.InputTargetCondition
:members:
:show-inheritance:
-
-.. autoclass:: GraphInputTensorTargetCondition
- :members:
- :show-inheritance:
-
-.. autoclass:: GraphInputGraphTargetCondition
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/condition/time_series_condition.rst b/docs/source/_rst/condition/time_series_condition.rst
new file mode 100644
index 000000000..49a5f8795
--- /dev/null
+++ b/docs/source/_rst/condition/time_series_condition.rst
@@ -0,0 +1,10 @@
+Time Series Condition
+=======================
+
+.. currentmodule:: pina.condition.time_series_condition
+
+.. automodule:: pina._src.condition.time_series_condition
+
+.. autoclass:: pina._src.condition.time_series_condition.TimeSeriesCondition
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/aggregator.rst b/docs/source/_rst/data/aggregator.rst
new file mode 100644
index 000000000..738a57524
--- /dev/null
+++ b/docs/source/_rst/data/aggregator.rst
@@ -0,0 +1,9 @@
+Aggregator
+================
+.. currentmodule:: pina.data.aggregator
+
+.. automodule:: pina._src.data.aggregator
+
+.. autoclass:: pina._src.data.aggregator._Aggregator
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/condition_subset.rst b/docs/source/_rst/data/condition_subset.rst
new file mode 100644
index 000000000..84c032dc8
--- /dev/null
+++ b/docs/source/_rst/data/condition_subset.rst
@@ -0,0 +1,9 @@
+Condition Subset
+================
+.. currentmodule:: pina.data.condition_subset
+
+.. automodule:: pina._src.data.condition_subset
+
+.. autoclass:: pina._src.data.condition_subset._ConditionSubset
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/creator.rst b/docs/source/_rst/data/creator.rst
new file mode 100644
index 000000000..5d836292d
--- /dev/null
+++ b/docs/source/_rst/data/creator.rst
@@ -0,0 +1,9 @@
+Creator
+=======
+.. currentmodule:: pina.data.creator
+
+.. automodule:: pina._src.data.creator
+
+.. autoclass:: pina._src.data.creator._Creator
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/data_module.rst b/docs/source/_rst/data/data_module.rst
index b7ffb14e0..e31dae2b9 100644
--- a/docs/source/_rst/data/data_module.rst
+++ b/docs/source/_rst/data/data_module.rst
@@ -2,14 +2,6 @@ DataModule
======================
.. currentmodule:: pina.data.data_module
-.. autoclass:: Collator
+.. autoclass:: pina._src.data.data_module.DataModule
:members:
:show-inheritance:
-
-.. autoclass:: PinaDataModule
- :members:
- :show-inheritance:
-
-.. autoclass:: PinaSampler
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/dataset.rst b/docs/source/_rst/data/dataset.rst
deleted file mode 100644
index b49b41db1..000000000
--- a/docs/source/_rst/data/dataset.rst
+++ /dev/null
@@ -1,19 +0,0 @@
-Dataset
-======================
-.. currentmodule:: pina.data.dataset
-
-.. autoclass:: PinaDataset
- :members:
- :show-inheritance:
-
-.. autoclass:: PinaDatasetFactory
- :members:
- :show-inheritance:
-
-.. autoclass:: PinaGraphDataset
- :members:
- :show-inheritance:
-
-.. autoclass:: PinaTensorDataset
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/manager/batch_manager.rst b/docs/source/_rst/data/manager/batch_manager.rst
new file mode 100644
index 000000000..5d7c36650
--- /dev/null
+++ b/docs/source/_rst/data/manager/batch_manager.rst
@@ -0,0 +1,9 @@
+Batch Manager
+======================
+.. currentmodule:: pina.data.manager.batch_manager
+
+.. automodule:: pina._src.data.manager.batch_manager
+
+.. autoclass:: pina._src.data.manager.batch_manager._BatchManager
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/manager/data_manager.rst b/docs/source/_rst/data/manager/data_manager.rst
new file mode 100644
index 000000000..9b32b8242
--- /dev/null
+++ b/docs/source/_rst/data/manager/data_manager.rst
@@ -0,0 +1,9 @@
+Data Manager
+======================
+.. currentmodule:: pina.data.manager.data_manager
+
+.. automodule:: pina._src.data.manager.data_manager
+
+.. autoclass:: pina._src.data.manager.data_manager._DataManager
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/manager/data_manager_interface.rst b/docs/source/_rst/data/manager/data_manager_interface.rst
new file mode 100644
index 000000000..e4a502abf
--- /dev/null
+++ b/docs/source/_rst/data/manager/data_manager_interface.rst
@@ -0,0 +1,9 @@
+Data Manager Interface
+=========================
+.. currentmodule:: pina.data.manager.data_manager_interface
+
+.. automodule:: pina._src.data.manager.data_manager_interface
+
+.. autoclass:: pina._src.data.manager.data_manager_interface._DataManagerInterface
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/manager/graph_data_manager.rst b/docs/source/_rst/data/manager/graph_data_manager.rst
new file mode 100644
index 000000000..bbbf23a52
--- /dev/null
+++ b/docs/source/_rst/data/manager/graph_data_manager.rst
@@ -0,0 +1,9 @@
+Graph Data Manager
+======================
+.. currentmodule:: pina.data.manager.graph_data_manager
+
+.. automodule:: pina._src.data.manager.graph_data_manager
+
+.. autoclass:: pina._src.data.manager.graph_data_manager._GraphDataManager
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/manager/tensor_data_manager.rst b/docs/source/_rst/data/manager/tensor_data_manager.rst
new file mode 100644
index 000000000..f8bb06028
--- /dev/null
+++ b/docs/source/_rst/data/manager/tensor_data_manager.rst
@@ -0,0 +1,9 @@
+Tensor Data Manager
+======================
+.. currentmodule:: pina.data.manager.tensor_data_manager
+
+.. automodule:: pina._src.data.manager.tensor_data_manager
+
+.. autoclass:: pina._src.data.manager.tensor_data_manager._TensorDataManager
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/data/single_batch_data_loader.rst b/docs/source/_rst/data/single_batch_data_loader.rst
new file mode 100644
index 000000000..7c1debb92
--- /dev/null
+++ b/docs/source/_rst/data/single_batch_data_loader.rst
@@ -0,0 +1,9 @@
+Single-Batch Data Loader
+===========================
+.. currentmodule:: pina.data.single_batch_data_loader
+
+.. automodule:: pina._src.data.single_batch_data_loader
+
+.. autoclass:: pina._src.data.single_batch_data_loader._SingleBatchDataLoader
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/domain/base_domain.rst b/docs/source/_rst/domain/base_domain.rst
index e6b9ce88c..3850ba4fa 100644
--- a/docs/source/_rst/domain/base_domain.rst
+++ b/docs/source/_rst/domain/base_domain.rst
@@ -2,8 +2,8 @@ BaseDomain
===========
.. currentmodule:: pina.domain.base_domain
-.. automodule:: pina.domain.base_domain
+.. automodule:: pina._src.domain.base_domain
-.. autoclass:: BaseDomain
+.. autoclass:: pina._src.domain.base_domain.BaseDomain
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/domain/base_operation.rst b/docs/source/_rst/domain/base_operation.rst
index cfa145f03..122048d81 100644
--- a/docs/source/_rst/domain/base_operation.rst
+++ b/docs/source/_rst/domain/base_operation.rst
@@ -2,8 +2,8 @@ BaseOperation
==============
.. currentmodule:: pina.domain.base_operation
-.. automodule:: pina.domain.base_operation
+.. automodule:: pina._src.domain.base_operation
-.. autoclass:: BaseOperation
+.. autoclass:: pina._src.domain.base_operation.BaseOperation
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/domain/cartesian_domain.rst b/docs/source/_rst/domain/cartesian_domain.rst
index 15491be8c..bc2afec03 100644
--- a/docs/source/_rst/domain/cartesian_domain.rst
+++ b/docs/source/_rst/domain/cartesian_domain.rst
@@ -2,9 +2,9 @@ CartesianDomain
======================
.. currentmodule:: pina.domain.cartesian_domain
-.. automodule:: pina.domain.cartesian_domain
+.. automodule:: pina._src.domain.cartesian_domain
-.. autoclass:: CartesianDomain
+.. autoclass:: pina._src.domain.cartesian_domain.CartesianDomain
:members:
:show-inheritance:
:noindex:
diff --git a/docs/source/_rst/domain/difference.rst b/docs/source/_rst/domain/difference.rst
index 0167c3062..91ffd4ec9 100644
--- a/docs/source/_rst/domain/difference.rst
+++ b/docs/source/_rst/domain/difference.rst
@@ -2,8 +2,8 @@ Difference
======================
.. currentmodule:: pina.domain.difference
-.. automodule:: pina.domain.difference
+.. automodule:: pina._src.domain.difference
-.. autoclass:: Difference
+.. autoclass:: pina._src.domain.difference.Difference
:members:
:show-inheritance:
diff --git a/docs/source/_rst/domain/domain_interface.rst b/docs/source/_rst/domain/domain_interface.rst
index 898896ba3..96594a23b 100644
--- a/docs/source/_rst/domain/domain_interface.rst
+++ b/docs/source/_rst/domain/domain_interface.rst
@@ -2,8 +2,8 @@ DomainInterface
================
.. currentmodule:: pina.domain.domain_interface
-.. automodule:: pina.domain.domain_interface
+.. automodule:: pina._src.domain.domain_interface
-.. autoclass:: DomainInterface
+.. autoclass:: pina._src.domain.domain_interface.DomainInterface
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/domain/ellipsoid_domain.rst b/docs/source/_rst/domain/ellipsoid_domain.rst
index 4a9799e29..2cbc5f7ec 100644
--- a/docs/source/_rst/domain/ellipsoid_domain.rst
+++ b/docs/source/_rst/domain/ellipsoid_domain.rst
@@ -2,9 +2,9 @@ EllipsoidDomain
======================
.. currentmodule:: pina.domain.ellipsoid_domain
-.. automodule:: pina.domain.ellipsoid_domain
+.. automodule:: pina._src.domain.ellipsoid_domain
-.. autoclass:: EllipsoidDomain
+.. autoclass:: pina._src.domain.ellipsoid_domain.EllipsoidDomain
:members:
:show-inheritance:
:noindex:
diff --git a/docs/source/_rst/domain/exclusion.rst b/docs/source/_rst/domain/exclusion.rst
index f624122ae..040b48416 100644
--- a/docs/source/_rst/domain/exclusion.rst
+++ b/docs/source/_rst/domain/exclusion.rst
@@ -2,8 +2,8 @@ Exclusion
======================
.. currentmodule:: pina.domain.exclusion
-.. automodule:: pina.domain.exclusion
+.. automodule:: pina._src.domain.exclusion
-.. autoclass:: Exclusion
+.. autoclass:: pina._src.domain.exclusion.Exclusion
:members:
:show-inheritance:
diff --git a/docs/source/_rst/domain/intersection.rst b/docs/source/_rst/domain/intersection.rst
index fade1d042..666fe0f00 100644
--- a/docs/source/_rst/domain/intersection.rst
+++ b/docs/source/_rst/domain/intersection.rst
@@ -2,8 +2,8 @@ Intersection
======================
.. currentmodule:: pina.domain.intersection
-.. automodule:: pina.domain.intersection
+.. automodule:: pina._src.domain.intersection
-.. autoclass:: Intersection
+.. autoclass:: pina._src.domain.intersection.Intersection
:members:
:show-inheritance:
diff --git a/docs/source/_rst/domain/operation_interface.rst b/docs/source/_rst/domain/operation_interface.rst
index 0acd393dc..42e92fbe8 100644
--- a/docs/source/_rst/domain/operation_interface.rst
+++ b/docs/source/_rst/domain/operation_interface.rst
@@ -2,8 +2,8 @@ OperationInterface
======================
.. currentmodule:: pina.domain.operation_interface
-.. automodule:: pina.domain.operation_interface
+.. automodule:: pina._src.domain.operation_interface
-.. autoclass:: OperationInterface
+.. autoclass:: pina._src.domain.operation_interface.OperationInterface
:members:
:show-inheritance:
diff --git a/docs/source/_rst/domain/simplex_domain.rst b/docs/source/_rst/domain/simplex_domain.rst
index 5f1d31c9b..0aba5f912 100644
--- a/docs/source/_rst/domain/simplex_domain.rst
+++ b/docs/source/_rst/domain/simplex_domain.rst
@@ -2,9 +2,9 @@ SimplexDomain
======================
.. currentmodule:: pina.domain.simplex_domain
-.. automodule:: pina.domain.simplex_domain
+.. automodule:: pina._src.domain.simplex_domain
-.. autoclass:: SimplexDomain
+.. autoclass:: pina._src.domain.simplex_domain.SimplexDomain
:members:
:show-inheritance:
:noindex:
diff --git a/docs/source/_rst/domain/union.rst b/docs/source/_rst/domain/union.rst
index 614bb351c..fc5ff92a9 100644
--- a/docs/source/_rst/domain/union.rst
+++ b/docs/source/_rst/domain/union.rst
@@ -2,8 +2,8 @@ Union
======================
.. currentmodule:: pina.domain.union
-.. automodule:: pina.domain.union
+.. automodule:: pina._src.domain.union
-.. autoclass:: Union
+.. autoclass:: pina._src.domain.union.Union
:members:
:show-inheritance:
diff --git a/docs/source/_rst/equation/base_equation.rst b/docs/source/_rst/equation/base_equation.rst
new file mode 100644
index 000000000..5bb98901f
--- /dev/null
+++ b/docs/source/_rst/equation/base_equation.rst
@@ -0,0 +1,7 @@
+Base Equation
+====================
+
+.. currentmodule:: pina.equation.base_equation
+.. autoclass:: pina._src.equation.base_equation.BaseEquation
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/equation/equation.rst b/docs/source/_rst/equation/equation.rst
index 33e19c957..edb350090 100644
--- a/docs/source/_rst/equation/equation.rst
+++ b/docs/source/_rst/equation/equation.rst
@@ -2,6 +2,6 @@ Equation
==========
.. currentmodule:: pina.equation.equation
-.. autoclass:: Equation
+.. autoclass:: pina._src.equation.equation.Equation
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/equation/equation_factory.rst b/docs/source/_rst/equation/equation_factory.rst
deleted file mode 100644
index 86390c6bd..000000000
--- a/docs/source/_rst/equation/equation_factory.rst
+++ /dev/null
@@ -1,43 +0,0 @@
-Equation Factory
-==================
-
-.. currentmodule:: pina.equation.equation_factory
-.. autoclass:: FixedValue
- :members:
- :show-inheritance:
-
-.. autoclass:: FixedGradient
- :members:
- :show-inheritance:
-
-.. autoclass:: FixedFlux
- :members:
- :show-inheritance:
-
-.. autoclass:: FixedLaplacian
- :members:
- :show-inheritance:
-
-.. autoclass:: Laplace
- :members:
- :show-inheritance:
-
-.. autoclass:: Advection
- :members:
- :show-inheritance:
-
-.. autoclass:: AllenCahn
- :members:
- :show-inheritance:
-
-.. autoclass:: DiffusionReaction
- :members:
- :show-inheritance:
-
-.. autoclass:: Helmholtz
- :members:
- :show-inheritance:
-
-.. autoclass:: Poisson
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/equation/equation_interface.rst b/docs/source/_rst/equation/equation_interface.rst
index cde7b0012..f16502831 100644
--- a/docs/source/_rst/equation/equation_interface.rst
+++ b/docs/source/_rst/equation/equation_interface.rst
@@ -2,6 +2,6 @@ Equation Interface
====================
.. currentmodule:: pina.equation.equation_interface
-.. autoclass:: EquationInterface
+.. autoclass:: pina._src.equation.equation_interface.EquationInterface
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/equation/system_equation.rst b/docs/source/_rst/equation/system_equation.rst
index 33c931cd9..88d1554f8 100644
--- a/docs/source/_rst/equation/system_equation.rst
+++ b/docs/source/_rst/equation/system_equation.rst
@@ -2,6 +2,6 @@ System Equation
=================
.. currentmodule:: pina.equation.system_equation
-.. autoclass:: SystemEquation
+.. autoclass:: pina._src.equation.system_equation.SystemEquation
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/equation/zoo/acoustic_wave_equation.rst b/docs/source/_rst/equation/zoo/acoustic_wave_equation.rst
new file mode 100644
index 000000000..5bc19d920
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/acoustic_wave_equation.rst
@@ -0,0 +1,7 @@
+AcousticWaveEquation
+=====================
+.. currentmodule:: pina.equation.zoo.acoustic_wave_equation
+
+.. automodule:: pina._src.equation.zoo.acoustic_wave_equation
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/advection_equation.rst b/docs/source/_rst/equation/zoo/advection_equation.rst
new file mode 100644
index 000000000..4386b3a3d
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/advection_equation.rst
@@ -0,0 +1,7 @@
+Advection Equation
+=====================
+.. currentmodule:: pina.equation.zoo.advection_equation
+
+.. automodule:: pina._src.equation.zoo.advection_equation
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/allen_cahn_equation.rst b/docs/source/_rst/equation/zoo/allen_cahn_equation.rst
new file mode 100644
index 000000000..fff220811
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/allen_cahn_equation.rst
@@ -0,0 +1,7 @@
+Allen Cahn Equation
+=====================
+.. currentmodule:: pina.equation.zoo.allen_cahn_equation
+
+.. automodule:: pina._src.equation.zoo.allen_cahn_equation
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/burgers_equation.rst b/docs/source/_rst/equation/zoo/burgers_equation.rst
new file mode 100644
index 000000000..8f478621f
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/burgers_equation.rst
@@ -0,0 +1,7 @@
+Burgers' Equation
+====================
+.. currentmodule:: pina.equation.zoo.burgers_equation
+
+.. automodule:: pina._src.equation.zoo.burgers_equation
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/diffusion_reaction_equation.rst b/docs/source/_rst/equation/zoo/diffusion_reaction_equation.rst
new file mode 100644
index 000000000..d45143074
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/diffusion_reaction_equation.rst
@@ -0,0 +1,7 @@
+Diffusion Reaction Equation
+==============================
+.. currentmodule:: pina.equation.zoo.diffusion_reaction_equation
+
+.. automodule:: pina._src.equation.zoo.diffusion_reaction_equation
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/fixed_flux.rst b/docs/source/_rst/equation/zoo/fixed_flux.rst
new file mode 100644
index 000000000..9b81db4b2
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/fixed_flux.rst
@@ -0,0 +1,7 @@
+Fixed Flux
+=====================
+.. currentmodule:: pina.equation.zoo.fixed_flux
+
+.. automodule:: pina._src.equation.zoo.fixed_flux
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/fixed_gradient.rst b/docs/source/_rst/equation/zoo/fixed_gradient.rst
new file mode 100644
index 000000000..f8da5dea8
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/fixed_gradient.rst
@@ -0,0 +1,7 @@
+Fixed Gradient
+=====================
+.. currentmodule:: pina.equation.zoo.fixed_gradient
+
+.. automodule:: pina._src.equation.zoo.fixed_gradient
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/fixed_laplacian.rst b/docs/source/_rst/equation/zoo/fixed_laplacian.rst
new file mode 100644
index 000000000..3123918a6
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/fixed_laplacian.rst
@@ -0,0 +1,7 @@
+Fixed Laplacian
+=====================
+.. currentmodule:: pina.equation.zoo.fixed_laplacian
+
+.. automodule:: pina._src.equation.zoo.fixed_laplacian
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/fixed_value.rst b/docs/source/_rst/equation/zoo/fixed_value.rst
new file mode 100644
index 000000000..29eaa0521
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/fixed_value.rst
@@ -0,0 +1,7 @@
+Fixed Value
+=====================
+.. currentmodule:: pina.equation.zoo.fixed_value
+
+.. automodule:: pina._src.equation.zoo.fixed_value
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/helmholtz_equation.rst b/docs/source/_rst/equation/zoo/helmholtz_equation.rst
new file mode 100644
index 000000000..7728b60ed
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/helmholtz_equation.rst
@@ -0,0 +1,9 @@
+Helmholtz Equation
+=====================
+.. currentmodule:: pina.equation.zoo.helmholtz_equation
+
+.. automodule:: pina._src.equation.zoo.helmholtz_equation
+
+.. autoclass:: pina._src.equation.zoo.helmholtz_equation.HelmholtzEquation
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/equation/zoo/poisson_equation.rst b/docs/source/_rst/equation/zoo/poisson_equation.rst
new file mode 100644
index 000000000..f23796450
--- /dev/null
+++ b/docs/source/_rst/equation/zoo/poisson_equation.rst
@@ -0,0 +1,9 @@
+Poisson Equation
+=====================
+.. currentmodule:: pina.equation.zoo.poisson_equation
+
+.. automodule:: pina._src.equation.zoo.poisson_equation
+
+.. autoclass:: pina._src.equation.zoo.poisson_equation.PoissonEquation
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/graph/graph.rst b/docs/source/_rst/graph/graph.rst
index 1921f83e0..58180f50f 100644
--- a/docs/source/_rst/graph/graph.rst
+++ b/docs/source/_rst/graph/graph.rst
@@ -3,7 +3,7 @@ Graph
.. currentmodule:: pina.graph
-.. autoclass:: Graph
+.. autoclass:: pina._src.core.graph.Graph
:members:
:private-members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/graph/graph_builder.rst b/docs/source/_rst/graph/graph_builder.rst
index 2508aecb7..f576fe7c7 100644
--- a/docs/source/_rst/graph/graph_builder.rst
+++ b/docs/source/_rst/graph/graph_builder.rst
@@ -3,7 +3,7 @@ GraphBuilder
.. currentmodule:: pina.graph
-.. autoclass:: GraphBuilder
+.. autoclass:: pina._src.core.graph.GraphBuilder
:members:
:private-members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/graph/knn_graph.rst b/docs/source/_rst/graph/knn_graph.rst
index 8ef0b190b..e31a004ab 100644
--- a/docs/source/_rst/graph/knn_graph.rst
+++ b/docs/source/_rst/graph/knn_graph.rst
@@ -3,7 +3,7 @@ KNNGraph
.. currentmodule:: pina.graph
-.. autoclass:: KNNGraph
+.. autoclass:: pina._src.core.graph.KNNGraph
:members:
:private-members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/graph/label_batch.rst b/docs/source/_rst/graph/label_batch.rst
index 7cd4d2684..5a68bde60 100644
--- a/docs/source/_rst/graph/label_batch.rst
+++ b/docs/source/_rst/graph/label_batch.rst
@@ -3,7 +3,7 @@ LabelBatch
.. currentmodule:: pina.graph
-.. autoclass:: LabelBatch
+.. autoclass:: pina._src.core.graph.LabelBatch
:members:
:private-members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/graph/radius_graph.rst b/docs/source/_rst/graph/radius_graph.rst
index 7414d2dc1..9db9fb174 100644
--- a/docs/source/_rst/graph/radius_graph.rst
+++ b/docs/source/_rst/graph/radius_graph.rst
@@ -3,7 +3,7 @@ RadiusGraph
.. currentmodule:: pina.graph
-.. autoclass:: RadiusGraph
+.. autoclass:: pina._src.core.graph.RadiusGraph
:members:
:private-members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/label_tensor.rst b/docs/source/_rst/label_tensor.rst
index 9eb227369..1b750ad97 100644
--- a/docs/source/_rst/label_tensor.rst
+++ b/docs/source/_rst/label_tensor.rst
@@ -2,8 +2,11 @@ LabelTensor
===========
.. currentmodule:: pina.label_tensor
+.. automodule:: pina._src.core.label_tensor
+ :no-members:
-.. autoclass:: LabelTensor
+
+.. autoclass:: pina._src.core.label_tensor.LabelTensor
:members:
:private-members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/base_dual_loss.rst b/docs/source/_rst/loss/base_dual_loss.rst
new file mode 100644
index 000000000..8037f894b
--- /dev/null
+++ b/docs/source/_rst/loss/base_dual_loss.rst
@@ -0,0 +1,9 @@
+Base Dual Loss
+================
+.. currentmodule:: pina.loss.base_dual_loss
+
+.. automodule:: pina._src.loss.base_dual_loss
+
+.. autoclass:: pina._src.loss.base_dual_loss.BaseDualLoss
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/dual_loss_interface.rst b/docs/source/_rst/loss/dual_loss_interface.rst
new file mode 100644
index 000000000..a6a005914
--- /dev/null
+++ b/docs/source/_rst/loss/dual_loss_interface.rst
@@ -0,0 +1,9 @@
+Dual Loss Interface
+===================
+.. currentmodule:: pina.loss.dual_loss_interface
+
+.. automodule:: pina._src.loss.dual_loss_interface
+
+.. autoclass:: pina._src.loss.dual_loss_interface.DualLossInterface
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/linear_weighting.rst b/docs/source/_rst/loss/linear_weighting.rst
deleted file mode 100644
index 16e6232d0..000000000
--- a/docs/source/_rst/loss/linear_weighting.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-LinearWeighting
-=============================
-.. currentmodule:: pina.loss.linear_weighting
-
-.. automodule:: pina.loss.linear_weighting
-
-.. autoclass:: LinearWeighting
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/loss/loss_interface.rst b/docs/source/_rst/loss/loss_interface.rst
deleted file mode 100644
index 8ff78c01e..000000000
--- a/docs/source/_rst/loss/loss_interface.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-LossInterface
-===============
-.. currentmodule:: pina.loss.loss_interface
-
-.. automodule:: pina.loss.loss_interface
-
-.. autoclass:: LossInterface
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/lp_loss.rst b/docs/source/_rst/loss/lp_loss.rst
new file mode 100644
index 000000000..4924d3445
--- /dev/null
+++ b/docs/source/_rst/loss/lp_loss.rst
@@ -0,0 +1,10 @@
+Lp Loss
+===============
+.. currentmodule:: pina.loss.lp_loss
+
+.. automodule:: pina._src.loss.lp_loss
+ :no-members:
+
+.. autoclass:: pina._src.loss.lp_loss.LpLoss
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/lploss.rst b/docs/source/_rst/loss/lploss.rst
deleted file mode 100644
index 37dfdfe3c..000000000
--- a/docs/source/_rst/loss/lploss.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-LpLoss
-===============
-.. currentmodule:: pina.loss.lp_loss
-
-.. autoclass:: LpLoss
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/ntk_weighting.rst b/docs/source/_rst/loss/ntk_weighting.rst
deleted file mode 100644
index 6d9d8816d..000000000
--- a/docs/source/_rst/loss/ntk_weighting.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-NeuralTangentKernelWeighting
-=============================
-.. currentmodule:: pina.loss.ntk_weighting
-
-.. automodule:: pina.loss.ntk_weighting
-
-.. autoclass:: NeuralTangentKernelWeighting
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/loss/power_loss.rst b/docs/source/_rst/loss/power_loss.rst
new file mode 100644
index 000000000..a0258c20f
--- /dev/null
+++ b/docs/source/_rst/loss/power_loss.rst
@@ -0,0 +1,9 @@
+Power Loss
+====================
+.. currentmodule:: pina.loss.power_loss
+
+.. automodule:: pina._src.loss.power_loss
+
+.. autoclass:: pina._src.loss.power_loss.PowerLoss
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/powerloss.rst b/docs/source/_rst/loss/powerloss.rst
deleted file mode 100644
index e4dee43b8..000000000
--- a/docs/source/_rst/loss/powerloss.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-PowerLoss
-====================
-.. currentmodule:: pina.loss.power_loss
-
-.. autoclass:: PowerLoss
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/scalar_weighting.rst b/docs/source/_rst/loss/scalar_weighting.rst
deleted file mode 100644
index 5ee82a785..000000000
--- a/docs/source/_rst/loss/scalar_weighting.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-ScalarWeighting
-===================
-.. currentmodule:: pina.loss.scalar_weighting
-
-.. automodule:: pina.loss.scalar_weighting
-
-.. autoclass:: ScalarWeighting
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/self_adaptive_weighting.rst b/docs/source/_rst/loss/self_adaptive_weighting.rst
deleted file mode 100644
index cd1daed1f..000000000
--- a/docs/source/_rst/loss/self_adaptive_weighting.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-SelfAdaptiveWeighting
-=============================
-.. currentmodule:: pina.loss.self_adaptive_weighting
-
-.. automodule:: pina.loss.self_adaptive_weighting
-
-.. autoclass:: SelfAdaptiveWeighting
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/loss/sinkhorn_loss.rst b/docs/source/_rst/loss/sinkhorn_loss.rst
new file mode 100644
index 000000000..17aa370ad
--- /dev/null
+++ b/docs/source/_rst/loss/sinkhorn_loss.rst
@@ -0,0 +1,11 @@
+Sinkhorn Loss
+===============
+
+.. currentmodule:: pina.loss.sinkhorn_loss
+
+.. automodule:: pina._src.loss.sinkhorn_loss
+ :no-members:
+
+.. autoclass:: pina._src.loss.sinkhorn_loss.SinkhornLoss
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/loss/weighting_interface.rst b/docs/source/_rst/loss/weighting_interface.rst
deleted file mode 100644
index 2b0fa1bdc..000000000
--- a/docs/source/_rst/loss/weighting_interface.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-WeightingInterface
-===================
-.. currentmodule:: pina.loss.weighting_interface
-
-.. automodule:: pina.loss.weighting_interface
-
-.. autoclass:: WeightingInterface
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/average_neural_operator.rst b/docs/source/_rst/model/average_neural_operator.rst
index 02211e9a8..a54107620 100644
--- a/docs/source/_rst/model/average_neural_operator.rst
+++ b/docs/source/_rst/model/average_neural_operator.rst
@@ -2,6 +2,6 @@ Averaging Neural Operator
==============================
.. currentmodule:: pina.model.average_neural_operator
-.. autoclass:: AveragingNeuralOperator
+.. autoclass:: pina._src.model.average_neural_operator.AveragingNeuralOperator
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/block/average_neural_operator_block.rst b/docs/source/_rst/model/block/average_neural_operator_block.rst
index 0072ec9d0..1e38fc215 100644
--- a/docs/source/_rst/model/block/average_neural_operator_block.rst
+++ b/docs/source/_rst/model/block/average_neural_operator_block.rst
@@ -2,7 +2,7 @@ Averaging Neural Operator Block
==================================
.. currentmodule:: pina.model.block.average_neural_operator_block
-.. autoclass:: AVNOBlock
+.. autoclass:: pina._src.model.block.average_neural_operator_block.AVNOBlock
:members:
:show-inheritance:
:noindex:
diff --git a/docs/source/_rst/model/block/convolution.rst b/docs/source/_rst/model/block/convolution.rst
index 4033d5d56..bd0d32e71 100644
--- a/docs/source/_rst/model/block/convolution.rst
+++ b/docs/source/_rst/model/block/convolution.rst
@@ -2,7 +2,7 @@ Continuous Convolution Block
===============================
.. currentmodule:: pina.model.block.convolution_2d
-.. autoclass:: ContinuousConvBlock
+.. autoclass:: pina._src.model.block.convolution_2d.ContinuousConvBlock
:members:
:show-inheritance:
:noindex:
diff --git a/docs/source/_rst/model/block/convolution_interface.rst b/docs/source/_rst/model/block/convolution_interface.rst
index f8e61c16c..c6708ca94 100644
--- a/docs/source/_rst/model/block/convolution_interface.rst
+++ b/docs/source/_rst/model/block/convolution_interface.rst
@@ -2,7 +2,7 @@ Continuous Convolution Interface
==================================
.. currentmodule:: pina.model.block.convolution
-.. autoclass:: BaseContinuousConv
+.. autoclass:: pina._src.model.block.convolution.BaseContinuousConv
:members:
:show-inheritance:
:noindex:
diff --git a/docs/source/_rst/model/block/enhanced_linear.rst b/docs/source/_rst/model/block/enhanced_linear.rst
index d08cf79bf..92e8d5581 100644
--- a/docs/source/_rst/model/block/enhanced_linear.rst
+++ b/docs/source/_rst/model/block/enhanced_linear.rst
@@ -2,7 +2,7 @@ EnhancedLinear Block
=====================
.. currentmodule:: pina.model.block.residual
-.. autoclass:: EnhancedLinear
+.. autoclass:: pina._src.model.block.residual.EnhancedLinear
:members:
:show-inheritance:
:noindex:
\ No newline at end of file
diff --git a/docs/source/_rst/model/block/fourier_block.rst b/docs/source/_rst/model/block/fourier_block.rst
index c0fff4deb..9b601bb3d 100644
--- a/docs/source/_rst/model/block/fourier_block.rst
+++ b/docs/source/_rst/model/block/fourier_block.rst
@@ -3,14 +3,14 @@ Fourier Neural Operator Block
.. currentmodule:: pina.model.block.fourier_block
-.. autoclass:: FourierBlock1D
+.. autoclass:: pina._src.model.block.fourier_block.FourierBlock1D
:members:
:show-inheritance:
-.. autoclass:: FourierBlock2D
+.. autoclass:: pina._src.model.block.fourier_block.FourierBlock2D
:members:
:show-inheritance:
-.. autoclass:: FourierBlock3D
+.. autoclass:: pina._src.model.block.fourier_block.FourierBlock3D
:members:
:show-inheritance:
diff --git a/docs/source/_rst/model/block/fourier_embedding.rst b/docs/source/_rst/model/block/fourier_embedding.rst
index 77eb3960c..48c8df41c 100644
--- a/docs/source/_rst/model/block/fourier_embedding.rst
+++ b/docs/source/_rst/model/block/fourier_embedding.rst
@@ -2,7 +2,7 @@ Fourier Feature Embedding
=======================================
.. currentmodule:: pina.model.block.embedding
-.. autoclass:: FourierFeatureEmbedding
+.. autoclass:: pina._src.model.block.embedding.FourierFeatureEmbedding
:members:
:show-inheritance:
diff --git a/docs/source/_rst/model/block/gno_block.rst b/docs/source/_rst/model/block/gno_block.rst
index 19a532bab..8ce3f2f30 100644
--- a/docs/source/_rst/model/block/gno_block.rst
+++ b/docs/source/_rst/model/block/gno_block.rst
@@ -2,7 +2,7 @@ Graph Neural Operator Block
===============================
.. currentmodule:: pina.model.block.gno_block
-.. autoclass:: GNOBlock
+.. autoclass:: pina._src.model.block.gno_block.GNOBlock
:members:
:show-inheritance:
:noindex:
diff --git a/docs/source/_rst/model/block/kan_block.rst b/docs/source/_rst/model/block/kan_block.rst
new file mode 100644
index 000000000..95ca239eb
--- /dev/null
+++ b/docs/source/_rst/model/block/kan_block.rst
@@ -0,0 +1,7 @@
+KANBlock
+=======================
+.. currentmodule:: pina.model.block.kan_block
+
+.. autoclass:: pina._src.model.block.kan_block.KANBlock
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/block/low_rank_block.rst b/docs/source/_rst/model/block/low_rank_block.rst
index 366068f79..83c7a11a0 100644
--- a/docs/source/_rst/model/block/low_rank_block.rst
+++ b/docs/source/_rst/model/block/low_rank_block.rst
@@ -2,7 +2,7 @@ Low Rank Neural Operator Block
=================================
.. currentmodule:: pina.model.block.low_rank_block
-.. autoclass:: LowRankBlock
+.. autoclass:: pina._src.model.block.low_rank_block.LowRankBlock
:members:
:show-inheritance:
:noindex:
diff --git a/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst b/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst
index 30121e5a6..51482496a 100644
--- a/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst
+++ b/docs/source/_rst/model/block/message_passing/deep_tensor_network_block.rst
@@ -2,7 +2,7 @@ Deep Tensor Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.deep_tensor_network_block
-.. autoclass:: DeepTensorNetworkBlock
+.. autoclass:: pina._src.model.block.message_passing.deep_tensor_network_block.DeepTensorNetworkBlock
:members:
:show-inheritance:
:noindex:
diff --git a/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst b/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst
index e2755c665..09966ea0a 100644
--- a/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst
+++ b/docs/source/_rst/model/block/message_passing/en_equivariant_network_block.rst
@@ -2,7 +2,7 @@ E(n) Equivariant Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.en_equivariant_network_block
-.. autoclass:: EnEquivariantNetworkBlock
+.. autoclass:: pina._src.model.block.message_passing.en_equivariant_network_block.EnEquivariantNetworkBlock
:members:
:show-inheritance:
:noindex:
\ No newline at end of file
diff --git a/docs/source/_rst/model/block/message_passing/equivariant_graph_neural_operator_block.rst b/docs/source/_rst/model/block/message_passing/equivariant_graph_neural_operator_block.rst
index 8d047f84e..b61c4f430 100644
--- a/docs/source/_rst/model/block/message_passing/equivariant_graph_neural_operator_block.rst
+++ b/docs/source/_rst/model/block/message_passing/equivariant_graph_neural_operator_block.rst
@@ -2,6 +2,6 @@ EquivariantGraphNeuralOperatorBlock
=====================================
.. currentmodule:: pina.model.block.message_passing.equivariant_graph_neural_operator_block
-.. autoclass:: EquivariantGraphNeuralOperatorBlock
+.. autoclass:: pina._src.model.block.message_passing.equivariant_graph_neural_operator_block.EquivariantGraphNeuralOperatorBlock
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/block/message_passing/interaction_network_block.rst b/docs/source/_rst/model/block/message_passing/interaction_network_block.rst
index ffac307e2..a4c86e562 100644
--- a/docs/source/_rst/model/block/message_passing/interaction_network_block.rst
+++ b/docs/source/_rst/model/block/message_passing/interaction_network_block.rst
@@ -2,7 +2,7 @@ Interaction Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.interaction_network_block
-.. autoclass:: InteractionNetworkBlock
+.. autoclass:: pina._src.model.block.message_passing.interaction_network_block.InteractionNetworkBlock
:members:
:show-inheritance:
:noindex:
\ No newline at end of file
diff --git a/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst b/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst
index e05203f33..bb66ee770 100644
--- a/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst
+++ b/docs/source/_rst/model/block/message_passing/radial_field_network_block.rst
@@ -2,7 +2,7 @@ Radial Field Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.radial_field_network_block
-.. autoclass:: RadialFieldNetworkBlock
+.. autoclass:: pina._src.model.block.message_passing.radial_field_network_block.RadialFieldNetworkBlock
:members:
:show-inheritance:
:noindex:
\ No newline at end of file
diff --git a/docs/source/_rst/model/block/orthogonal.rst b/docs/source/_rst/model/block/orthogonal.rst
index 21d12998a..a9fc727fb 100644
--- a/docs/source/_rst/model/block/orthogonal.rst
+++ b/docs/source/_rst/model/block/orthogonal.rst
@@ -2,6 +2,6 @@ Orthogonal Block
======================
.. currentmodule:: pina.model.block.orthogonal
-.. autoclass:: OrthogonalBlock
+.. autoclass:: pina._src.model.block.orthogonal.OrthogonalBlock
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/block/pbc_embedding.rst b/docs/source/_rst/model/block/pbc_embedding.rst
index f469644af..e79ae9514 100644
--- a/docs/source/_rst/model/block/pbc_embedding.rst
+++ b/docs/source/_rst/model/block/pbc_embedding.rst
@@ -2,7 +2,7 @@ Periodic Boundary Condition Embedding
=======================================
.. currentmodule:: pina.model.block.embedding
-.. autoclass:: PeriodicBoundaryEmbedding
+.. autoclass:: pina._src.model.block.embedding.PeriodicBoundaryEmbedding
:members:
:show-inheritance:
diff --git a/docs/source/_rst/model/block/pirate_network_block.rst b/docs/source/_rst/model/block/pirate_network_block.rst
index 5d0428a68..f534d3cb0 100644
--- a/docs/source/_rst/model/block/pirate_network_block.rst
+++ b/docs/source/_rst/model/block/pirate_network_block.rst
@@ -2,7 +2,7 @@ PirateNet Block
=======================================
.. currentmodule:: pina.model.block.pirate_network_block
-.. autoclass:: PirateNetBlock
+.. autoclass:: pina._src.model.block.pirate_network_block.PirateNetBlock
:members:
:show-inheritance:
diff --git a/docs/source/_rst/model/block/pod_block.rst b/docs/source/_rst/model/block/pod_block.rst
index 4b66e2c97..98fadbb1e 100644
--- a/docs/source/_rst/model/block/pod_block.rst
+++ b/docs/source/_rst/model/block/pod_block.rst
@@ -2,6 +2,6 @@ Proper Orthogonal Decomposition Block
============================================
.. currentmodule:: pina.model.block.pod_block
-.. autoclass:: PODBlock
+.. autoclass:: pina._src.model.block.pod_block.PODBlock
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/block/rbf_block.rst b/docs/source/_rst/model/block/rbf_block.rst
index 545f14d08..b8997d21b 100644
--- a/docs/source/_rst/model/block/rbf_block.rst
+++ b/docs/source/_rst/model/block/rbf_block.rst
@@ -2,6 +2,6 @@ Radias Basis Function Block
=============================
.. currentmodule:: pina.model.block.rbf_block
-.. autoclass:: RBFBlock
+.. autoclass:: pina._src.model.block.rbf_block.RBFBlock
:members:
:show-inheritance:
diff --git a/docs/source/_rst/model/block/residual.rst b/docs/source/_rst/model/block/residual.rst
index 69741c74c..d0e478563 100644
--- a/docs/source/_rst/model/block/residual.rst
+++ b/docs/source/_rst/model/block/residual.rst
@@ -2,6 +2,6 @@ Residual Block
===================
.. currentmodule:: pina.model.block.residual
-.. autoclass:: ResidualBlock
+.. autoclass:: pina._src.model.block.residual.ResidualBlock
:members:
:show-inheritance:
diff --git a/docs/source/_rst/model/block/spectral.rst b/docs/source/_rst/model/block/spectral.rst
index 3c80f3dd8..1ee0e1d19 100644
--- a/docs/source/_rst/model/block/spectral.rst
+++ b/docs/source/_rst/model/block/spectral.rst
@@ -2,14 +2,14 @@ Spectral Convolution Block
============================
.. currentmodule:: pina.model.block.spectral
-.. autoclass:: SpectralConvBlock1D
+.. autoclass:: pina._src.model.block.spectral.SpectralConvBlock1D
:members:
:show-inheritance:
-.. autoclass:: SpectralConvBlock2D
+.. autoclass:: pina._src.model.block.spectral.SpectralConvBlock2D
:members:
:show-inheritance:
-.. autoclass:: SpectralConvBlock3D
+.. autoclass:: pina._src.model.block.spectral.SpectralConvBlock3D
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/deeponet.rst b/docs/source/_rst/model/deeponet.rst
index 0ca08242d..eef25dcae 100644
--- a/docs/source/_rst/model/deeponet.rst
+++ b/docs/source/_rst/model/deeponet.rst
@@ -2,6 +2,6 @@ DeepONet
===========
.. currentmodule:: pina.model.deeponet
-.. autoclass:: DeepONet
+.. autoclass:: pina._src.model.deeponet.DeepONet
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/equivariant_graph_neural_operator.rst b/docs/source/_rst/model/equivariant_graph_neural_operator.rst
index a11edcc00..e100f5c1e 100644
--- a/docs/source/_rst/model/equivariant_graph_neural_operator.rst
+++ b/docs/source/_rst/model/equivariant_graph_neural_operator.rst
@@ -2,6 +2,6 @@ EquivariantGraphNeuralOperator
=================================
.. currentmodule:: pina.model.equivariant_graph_neural_operator
-.. autoclass:: EquivariantGraphNeuralOperator
+.. autoclass:: pina._src.model.equivariant_graph_neural_operator.EquivariantGraphNeuralOperator
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/feed_forward.rst b/docs/source/_rst/model/feed_forward.rst
index 2dea8e550..be75ed70b 100644
--- a/docs/source/_rst/model/feed_forward.rst
+++ b/docs/source/_rst/model/feed_forward.rst
@@ -2,6 +2,6 @@ FeedForward
======================
.. currentmodule:: pina.model.feed_forward
-.. autoclass:: FeedForward
+.. autoclass:: pina._src.model.feed_forward.FeedForward
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/fourier_integral_kernel.rst b/docs/source/_rst/model/fourier_integral_kernel.rst
index b1fb484fe..dba63c429 100644
--- a/docs/source/_rst/model/fourier_integral_kernel.rst
+++ b/docs/source/_rst/model/fourier_integral_kernel.rst
@@ -2,6 +2,6 @@ FourierIntegralKernel
=========================
.. currentmodule:: pina.model.fourier_neural_operator
-.. autoclass:: FourierIntegralKernel
+.. autoclass:: pina._src.model.fourier_neural_operator.FourierIntegralKernel
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/fourier_neural_operator.rst b/docs/source/_rst/model/fourier_neural_operator.rst
index e77494fd0..14cb52667 100644
--- a/docs/source/_rst/model/fourier_neural_operator.rst
+++ b/docs/source/_rst/model/fourier_neural_operator.rst
@@ -2,6 +2,6 @@ FNO
===========
.. currentmodule:: pina.model.fourier_neural_operator
-.. autoclass:: FNO
+.. autoclass:: pina._src.model.fourier_neural_operator.FNO
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/graph_neural_operator.rst b/docs/source/_rst/model/graph_neural_operator.rst
index fbb8600e5..7f7b7ed6b 100644
--- a/docs/source/_rst/model/graph_neural_operator.rst
+++ b/docs/source/_rst/model/graph_neural_operator.rst
@@ -2,6 +2,6 @@ GraphNeuralOperator
=======================
.. currentmodule:: pina.model.graph_neural_operator
-.. autoclass:: GraphNeuralOperator
+.. autoclass:: pina._src.model.graph_neural_operator.GraphNeuralOperator
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/graph_neural_operator_integral_kernel.rst b/docs/source/_rst/model/graph_neural_operator_integral_kernel.rst
index cf15a31a5..45f78c366 100644
--- a/docs/source/_rst/model/graph_neural_operator_integral_kernel.rst
+++ b/docs/source/_rst/model/graph_neural_operator_integral_kernel.rst
@@ -2,6 +2,6 @@ GraphNeuralKernel
=======================
.. currentmodule:: pina.model.graph_neural_operator
-.. autoclass:: GraphNeuralKernel
+.. autoclass:: pina._src.model.graph_neural_operator.GraphNeuralKernel
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/kernel_neural_operator.rst b/docs/source/_rst/model/kernel_neural_operator.rst
index d693afac5..75a39b223 100644
--- a/docs/source/_rst/model/kernel_neural_operator.rst
+++ b/docs/source/_rst/model/kernel_neural_operator.rst
@@ -2,6 +2,6 @@ KernelNeuralOperator
=======================
.. currentmodule:: pina.model.kernel_neural_operator
-.. autoclass:: KernelNeuralOperator
+.. autoclass:: pina._src.model.kernel_neural_operator.KernelNeuralOperator
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/kolmogorov_arnold_network.rst b/docs/source/_rst/model/kolmogorov_arnold_network.rst
new file mode 100644
index 000000000..0211611f4
--- /dev/null
+++ b/docs/source/_rst/model/kolmogorov_arnold_network.rst
@@ -0,0 +1,7 @@
+KolmogorovArnoldNetwork
+===========================
+.. currentmodule:: pina.model.kolmogorov_arnold_network
+
+.. autoclass:: pina._src.model.kolmogorov_arnold_network.KolmogorovArnoldNetwork
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/low_rank_neural_operator.rst b/docs/source/_rst/model/low_rank_neural_operator.rst
index 22fe7cc93..e0362d144 100644
--- a/docs/source/_rst/model/low_rank_neural_operator.rst
+++ b/docs/source/_rst/model/low_rank_neural_operator.rst
@@ -2,6 +2,6 @@ Low Rank Neural Operator
==============================
.. currentmodule:: pina.model.low_rank_neural_operator
-.. autoclass:: LowRankNeuralOperator
+.. autoclass:: pina._src.model.low_rank_neural_operator.LowRankNeuralOperator
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/mionet.rst b/docs/source/_rst/model/mionet.rst
index fe6281710..1888d911e 100644
--- a/docs/source/_rst/model/mionet.rst
+++ b/docs/source/_rst/model/mionet.rst
@@ -2,6 +2,6 @@ MIONet
===========
.. currentmodule:: pina.model.deeponet
-.. autoclass:: MIONet
+.. autoclass:: pina._src.model.deeponet.MIONet
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/multi_feed_forward.rst b/docs/source/_rst/model/multi_feed_forward.rst
index aa79580ee..458173ced 100644
--- a/docs/source/_rst/model/multi_feed_forward.rst
+++ b/docs/source/_rst/model/multi_feed_forward.rst
@@ -2,6 +2,6 @@ MultiFeedForward
==================
.. currentmodule:: pina.model.multi_feed_forward
-.. autoclass:: MultiFeedForward
+.. autoclass:: pina._src.model.multi_feed_forward.MultiFeedForward
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/pirate_network.rst b/docs/source/_rst/model/pirate_network.rst
index 5b374c247..a60449a6c 100644
--- a/docs/source/_rst/model/pirate_network.rst
+++ b/docs/source/_rst/model/pirate_network.rst
@@ -2,6 +2,6 @@ PirateNet
=======================
.. currentmodule:: pina.model.pirate_network
-.. autoclass:: PirateNet
+.. autoclass:: pina._src.model.pirate_network.PirateNet
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/residual_feed_forward.rst b/docs/source/_rst/model/residual_feed_forward.rst
index 66d83a42c..d8ce08152 100644
--- a/docs/source/_rst/model/residual_feed_forward.rst
+++ b/docs/source/_rst/model/residual_feed_forward.rst
@@ -2,6 +2,6 @@ ResidualFeedForward
======================
.. currentmodule:: pina.model.feed_forward
-.. autoclass:: ResidualFeedForward
+.. autoclass:: pina._src.model.feed_forward.ResidualFeedForward
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/sindy.rst b/docs/source/_rst/model/sindy.rst
index bd507603b..f07ca6d30 100644
--- a/docs/source/_rst/model/sindy.rst
+++ b/docs/source/_rst/model/sindy.rst
@@ -2,6 +2,6 @@ SINDy
=======================
.. currentmodule:: pina.model.sindy
-.. autoclass:: SINDy
+.. autoclass:: pina._src.model.sindy.SINDy
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/spline.rst b/docs/source/_rst/model/spline.rst
index aa7450b70..278a95d3b 100644
--- a/docs/source/_rst/model/spline.rst
+++ b/docs/source/_rst/model/spline.rst
@@ -2,6 +2,6 @@ Spline
========
.. currentmodule:: pina.model.spline
-.. autoclass:: Spline
+.. autoclass:: pina._src.model.spline.Spline
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/spline_surface.rst b/docs/source/_rst/model/spline_surface.rst
index 6bbf137d8..9b204cd22 100644
--- a/docs/source/_rst/model/spline_surface.rst
+++ b/docs/source/_rst/model/spline_surface.rst
@@ -2,6 +2,6 @@ Spline Surface
================
.. currentmodule:: pina.model.spline_surface
-.. autoclass:: SplineSurface
+.. autoclass:: pina._src.model.spline_surface.SplineSurface
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/model/vectorized_spline.rst b/docs/source/_rst/model/vectorized_spline.rst
new file mode 100644
index 000000000..08522bc54
--- /dev/null
+++ b/docs/source/_rst/model/vectorized_spline.rst
@@ -0,0 +1,7 @@
+VectorizedSpline
+=======================
+.. currentmodule:: pina.model.vectorized_spline
+
+.. autoclass:: pina._src.model.vectorized_spline.VectorizedSpline
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/operator.rst b/docs/source/_rst/operator.rst
index 42746a6f8..fe0ad0398 100644
--- a/docs/source/_rst/operator.rst
+++ b/docs/source/_rst/operator.rst
@@ -3,6 +3,7 @@ Operators
.. currentmodule:: pina.operator
-.. automodule:: pina.operator
+
+.. automodule:: pina._src.core.operator
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/optim/optimizer_interface.rst b/docs/source/_rst/optim/optimizer_interface.rst
index 88c18e8f5..23a933bae 100644
--- a/docs/source/_rst/optim/optimizer_interface.rst
+++ b/docs/source/_rst/optim/optimizer_interface.rst
@@ -1,7 +1,7 @@
-Optimizer
-============
+Optimizer Interface
+=====================
.. currentmodule:: pina.optim.optimizer_interface
-.. autoclass:: Optimizer
+.. autoclass:: pina._src.optim.optimizer_interface.OptimizerInterface
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/optim/scheduler_interface.rst b/docs/source/_rst/optim/scheduler_interface.rst
index ab8ee292e..03b3e83f7 100644
--- a/docs/source/_rst/optim/scheduler_interface.rst
+++ b/docs/source/_rst/optim/scheduler_interface.rst
@@ -1,7 +1,7 @@
-Scheduler
-=============
+Scheduler Interface
+=====================
.. currentmodule:: pina.optim.scheduler_interface
-.. autoclass:: Scheduler
+.. autoclass:: pina._src.optim.scheduler_interface.SchedulerInterface
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/optim/torch_optimizer.rst b/docs/source/_rst/optim/torch_optimizer.rst
index 3e6c9d912..54bfe9a3a 100644
--- a/docs/source/_rst/optim/torch_optimizer.rst
+++ b/docs/source/_rst/optim/torch_optimizer.rst
@@ -1,7 +1,7 @@
-TorchOptimizer
+Torch Optimizer
===============
.. currentmodule:: pina.optim.torch_optimizer
-.. autoclass:: TorchOptimizer
+.. autoclass:: pina._src.optim.torch_optimizer.TorchOptimizer
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/optim/torch_scheduler.rst b/docs/source/_rst/optim/torch_scheduler.rst
index 5c3e4df36..59260533e 100644
--- a/docs/source/_rst/optim/torch_scheduler.rst
+++ b/docs/source/_rst/optim/torch_scheduler.rst
@@ -1,7 +1,7 @@
-TorchScheduler
+Torch Scheduler
===============
.. currentmodule:: pina.optim.torch_scheduler
-.. autoclass:: TorchScheduler
+.. autoclass:: pina._src.optim.torch_scheduler.TorchScheduler
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/problem/abstract_problem.rst b/docs/source/_rst/problem/abstract_problem.rst
deleted file mode 100644
index 143909e1b..000000000
--- a/docs/source/_rst/problem/abstract_problem.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AbstractProblem
-===============
-.. currentmodule:: pina.problem.abstract_problem
-
-.. automodule:: pina.problem.abstract_problem
-
-.. autoclass:: AbstractProblem
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/problem/base_problem.rst b/docs/source/_rst/problem/base_problem.rst
new file mode 100644
index 000000000..2261a90f7
--- /dev/null
+++ b/docs/source/_rst/problem/base_problem.rst
@@ -0,0 +1,9 @@
+Base Problem
+===============
+.. currentmodule:: pina.problem.base_problem
+
+.. automodule:: pina._src.problem.base_problem
+
+.. autoclass:: pina._src.problem.base_problem.BaseProblem
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/inverse_problem.rst b/docs/source/_rst/problem/inverse_problem.rst
index 5ce306ffc..4b5de05cb 100644
--- a/docs/source/_rst/problem/inverse_problem.rst
+++ b/docs/source/_rst/problem/inverse_problem.rst
@@ -2,8 +2,8 @@ InverseProblem
==============
.. currentmodule:: pina.problem.inverse_problem
-.. automodule:: pina.problem.inverse_problem
+.. automodule:: pina._src.problem.inverse_problem
-.. autoclass:: InverseProblem
+.. autoclass:: pina._src.problem.inverse_problem.InverseProblem
:members:
:show-inheritance:
diff --git a/docs/source/_rst/problem/parametric_problem.rst b/docs/source/_rst/problem/parametric_problem.rst
index 8f217fbbe..1a5e83490 100644
--- a/docs/source/_rst/problem/parametric_problem.rst
+++ b/docs/source/_rst/problem/parametric_problem.rst
@@ -2,8 +2,8 @@ ParametricProblem
====================
.. currentmodule:: pina.problem.parametric_problem
-.. automodule:: pina.problem.parametric_problem
+.. automodule:: pina._src.problem.parametric_problem
-.. autoclass:: ParametricProblem
+.. autoclass:: pina._src.problem.parametric_problem.ParametricProblem
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/problem/problem_interface.rst b/docs/source/_rst/problem/problem_interface.rst
new file mode 100644
index 000000000..08136e23c
--- /dev/null
+++ b/docs/source/_rst/problem/problem_interface.rst
@@ -0,0 +1,9 @@
+ProblemInterface
+===================
+.. currentmodule:: pina.problem.problem_interface
+
+.. automodule:: pina._src.problem.problem_interface
+
+.. autoclass:: pina._src.problem.problem_interface.ProblemInterface
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/spatial_problem.rst b/docs/source/_rst/problem/spatial_problem.rst
index 90ec6ec3c..757243ef1 100644
--- a/docs/source/_rst/problem/spatial_problem.rst
+++ b/docs/source/_rst/problem/spatial_problem.rst
@@ -2,8 +2,8 @@ SpatialProblem
==============
.. currentmodule:: pina.problem.spatial_problem
-.. automodule:: pina.problem.spatial_problem
+.. automodule:: pina._src.problem.spatial_problem
-.. autoclass:: SpatialProblem
+.. autoclass:: pina._src.problem.spatial_problem.SpatialProblem
:members:
:show-inheritance:
diff --git a/docs/source/_rst/problem/time_dependent_problem.rst b/docs/source/_rst/problem/time_dependent_problem.rst
index db94121c2..dda1e07f1 100644
--- a/docs/source/_rst/problem/time_dependent_problem.rst
+++ b/docs/source/_rst/problem/time_dependent_problem.rst
@@ -2,8 +2,8 @@ TimeDependentProblem
====================
.. currentmodule:: pina.problem.time_dependent_problem
-.. automodule:: pina.problem.time_dependent_problem
+.. automodule:: pina._src.problem.time_dependent_problem
-.. autoclass:: TimeDependentProblem
+.. autoclass:: pina._src.problem.time_dependent_problem.TimeDependentProblem
:members:
:show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/acoustic_wave.rst b/docs/source/_rst/problem/zoo/acoustic_wave.rst
deleted file mode 100644
index 4a9489667..000000000
--- a/docs/source/_rst/problem/zoo/acoustic_wave.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AcousticWaveProblem
-=====================
-.. currentmodule:: pina.problem.zoo.acoustic_wave
-
-.. automodule:: pina.problem.zoo.acoustic_wave
-
-.. autoclass:: AcousticWaveProblem
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/acoustic_wave_problem.rst b/docs/source/_rst/problem/zoo/acoustic_wave_problem.rst
new file mode 100644
index 000000000..c6acb93f1
--- /dev/null
+++ b/docs/source/_rst/problem/zoo/acoustic_wave_problem.rst
@@ -0,0 +1,9 @@
+AcousticWaveProblem
+=====================
+.. currentmodule:: pina.problem.zoo.acoustic_wave_problem
+
+.. automodule:: pina._src.problem.zoo.acoustic_wave_problem
+
+.. autoclass:: pina._src.problem.zoo.acoustic_wave_problem.AcousticWaveProblem
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/advection.rst b/docs/source/_rst/problem/zoo/advection.rst
deleted file mode 100644
index b83cc9d99..000000000
--- a/docs/source/_rst/problem/zoo/advection.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AdvectionProblem
-==================
-.. currentmodule:: pina.problem.zoo.advection
-
-.. automodule:: pina.problem.zoo.advection
-
-.. autoclass:: AdvectionProblem
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/advection_problem.rst b/docs/source/_rst/problem/zoo/advection_problem.rst
new file mode 100644
index 000000000..df37679cb
--- /dev/null
+++ b/docs/source/_rst/problem/zoo/advection_problem.rst
@@ -0,0 +1,9 @@
+AdvectionProblem
+==================
+.. currentmodule:: pina.problem.zoo.advection_problem
+
+.. automodule:: pina._src.problem.zoo.advection_problem
+
+.. autoclass:: pina._src.problem.zoo.advection_problem.AdvectionProblem
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/allen_cahn.rst b/docs/source/_rst/problem/zoo/allen_cahn.rst
deleted file mode 100644
index ada3465d1..000000000
--- a/docs/source/_rst/problem/zoo/allen_cahn.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-AllenCahnProblem
-==================
-.. currentmodule:: pina.problem.zoo.allen_cahn
-
-.. automodule:: pina.problem.zoo.allen_cahn
-
-.. autoclass:: AllenCahnProblem
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/allen_cahn_problem.rst b/docs/source/_rst/problem/zoo/allen_cahn_problem.rst
new file mode 100644
index 000000000..463be3a55
--- /dev/null
+++ b/docs/source/_rst/problem/zoo/allen_cahn_problem.rst
@@ -0,0 +1,9 @@
+AllenCahnProblem
+==================
+.. currentmodule:: pina.problem.zoo.allen_cahn_problem
+
+.. automodule:: pina._src.problem.zoo.allen_cahn_problem
+
+.. autoclass:: pina._src.problem.zoo.allen_cahn_problem.AllenCahnProblem
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/burgers_problem.rst b/docs/source/_rst/problem/zoo/burgers_problem.rst
new file mode 100644
index 000000000..75151d8d8
--- /dev/null
+++ b/docs/source/_rst/problem/zoo/burgers_problem.rst
@@ -0,0 +1,9 @@
+Burgers' Problem
+=====================
+.. currentmodule:: pina.problem.zoo.burgers_problem
+
+.. automodule:: pina._src.problem.zoo.burgers_problem
+
+.. autoclass:: pina._src.problem.zoo.burgers_problem.BurgersProblem
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/diffusion_reaction.rst b/docs/source/_rst/problem/zoo/diffusion_reaction.rst
deleted file mode 100644
index 0cad0fd67..000000000
--- a/docs/source/_rst/problem/zoo/diffusion_reaction.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-DiffusionReactionProblem
-=========================
-.. currentmodule:: pina.problem.zoo.diffusion_reaction
-
-.. automodule:: pina.problem.zoo.diffusion_reaction
-
-.. autoclass:: DiffusionReactionProblem
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/diffusion_reaction_problem.rst b/docs/source/_rst/problem/zoo/diffusion_reaction_problem.rst
new file mode 100644
index 000000000..307a56c52
--- /dev/null
+++ b/docs/source/_rst/problem/zoo/diffusion_reaction_problem.rst
@@ -0,0 +1,9 @@
+DiffusionReactionProblem
+=========================
+.. currentmodule:: pina.problem.zoo.diffusion_reaction_problem
+
+.. automodule:: pina._src.problem.zoo.diffusion_reaction_problem
+
+.. autoclass:: pina._src.problem.zoo.diffusion_reaction_problem.DiffusionReactionProblem
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/helmholtz.rst b/docs/source/_rst/problem/zoo/helmholtz.rst
deleted file mode 100644
index af4ec7dbc..000000000
--- a/docs/source/_rst/problem/zoo/helmholtz.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-HelmholtzProblem
-==================
-.. currentmodule:: pina.problem.zoo.helmholtz
-
-.. automodule:: pina.problem.zoo.helmholtz
-
-.. autoclass:: HelmholtzProblem
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/helmholtz_problem.rst b/docs/source/_rst/problem/zoo/helmholtz_problem.rst
new file mode 100644
index 000000000..952578a2b
--- /dev/null
+++ b/docs/source/_rst/problem/zoo/helmholtz_problem.rst
@@ -0,0 +1,9 @@
+HelmholtzProblem
+==================
+.. currentmodule:: pina.problem.zoo.helmholtz_problem
+
+.. automodule:: pina._src.problem.zoo.helmholtz_problem
+
+.. autoclass:: pina._src.problem.zoo.helmholtz_problem.HelmholtzProblem
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/inverse_poisson_2d_square.rst b/docs/source/_rst/problem/zoo/inverse_poisson_2d_square.rst
deleted file mode 100644
index 727c17b47..000000000
--- a/docs/source/_rst/problem/zoo/inverse_poisson_2d_square.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-InversePoisson2DSquareProblem
-==============================
-.. currentmodule:: pina.problem.zoo.inverse_poisson_2d_square
-
-.. automodule:: pina.problem.zoo.inverse_poisson_2d_square
-
-.. autoclass:: InversePoisson2DSquareProblem
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/inverse_poisson_problem.rst b/docs/source/_rst/problem/zoo/inverse_poisson_problem.rst
new file mode 100644
index 000000000..503eb21bf
--- /dev/null
+++ b/docs/source/_rst/problem/zoo/inverse_poisson_problem.rst
@@ -0,0 +1,9 @@
+InversePoisson2DSquareProblem
+==============================
+.. currentmodule:: pina.problem.zoo.inverse_poisson_problem
+
+.. automodule:: pina._src.problem.zoo.inverse_poisson_problem
+
+.. autoclass:: pina._src.problem.zoo.inverse_poisson_problem.InversePoisson2DSquareProblem
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/poisson_2d_square.rst b/docs/source/_rst/problem/zoo/poisson_2d_square.rst
deleted file mode 100644
index 718c33ccc..000000000
--- a/docs/source/_rst/problem/zoo/poisson_2d_square.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-Poisson2DSquareProblem
-========================
-.. currentmodule:: pina.problem.zoo.poisson_2d_square
-
-.. automodule:: pina.problem.zoo.poisson_2d_square
-
-.. autoclass:: Poisson2DSquareProblem
- :members:
- :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/poisson_problem.rst b/docs/source/_rst/problem/zoo/poisson_problem.rst
new file mode 100644
index 000000000..a480a8953
--- /dev/null
+++ b/docs/source/_rst/problem/zoo/poisson_problem.rst
@@ -0,0 +1,9 @@
+Poisson2DSquareProblem
+========================
+.. currentmodule:: pina.problem.zoo.poisson_problem
+
+.. automodule:: pina._src.problem.zoo.poisson_problem
+
+.. autoclass:: pina._src.problem.zoo.poisson_problem.Poisson2DSquareProblem
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/problem/zoo/supervised_problem.rst b/docs/source/_rst/problem/zoo/supervised_problem.rst
index aad7d5aa5..6bf368376 100644
--- a/docs/source/_rst/problem/zoo/supervised_problem.rst
+++ b/docs/source/_rst/problem/zoo/supervised_problem.rst
@@ -2,8 +2,8 @@ SupervisedProblem
==================
.. currentmodule:: pina.problem.zoo.supervised_problem
-.. automodule:: pina.problem.zoo.supervised_problem
+.. automodule:: pina._src.problem.zoo.supervised_problem
-.. autoclass:: SupervisedProblem
+.. autoclass:: pina._src.problem.zoo.supervised_problem.SupervisedProblem
:members:
:show-inheritance:
diff --git a/docs/source/_rst/solver/autoregressive_ensemble_solver.rst b/docs/source/_rst/solver/autoregressive_ensemble_solver.rst
new file mode 100644
index 000000000..ba90c826f
--- /dev/null
+++ b/docs/source/_rst/solver/autoregressive_ensemble_solver.rst
@@ -0,0 +1,10 @@
+Autoregressive Ensemble Solver
+=================================
+.. currentmodule:: pina.solver.autoregressive_ensemble_solver
+
+.. automodule:: pina._src.solver.autoregressive_ensemble_solver
+
+.. autoclass:: pina._src.solver.autoregressive_ensemble_solver.AutoregressiveEnsembleSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/autoregressive_single_model_solver.rst b/docs/source/_rst/solver/autoregressive_single_model_solver.rst
new file mode 100644
index 000000000..217c1ff59
--- /dev/null
+++ b/docs/source/_rst/solver/autoregressive_single_model_solver.rst
@@ -0,0 +1,10 @@
+Autoregressive Single Model Solver
+======================================
+.. currentmodule:: pina.solver.autoregressive_single_model_solver
+
+.. automodule:: pina._src.solver.autoregressive_single_model_solver
+
+.. autoclass:: pina._src.solver.autoregressive_single_model_solver.AutoregressiveSingleModelSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/base_solver.rst b/docs/source/_rst/solver/base_solver.rst
new file mode 100644
index 000000000..939b94311
--- /dev/null
+++ b/docs/source/_rst/solver/base_solver.rst
@@ -0,0 +1,10 @@
+Base Solver
+=================================
+.. currentmodule:: pina.solver.base_solver
+
+.. automodule:: pina._src.solver.base_solver
+
+.. autoclass:: pina._src.solver.base_solver.BaseSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/causal_physics_informed_single_model_solver.rst b/docs/source/_rst/solver/causal_physics_informed_single_model_solver.rst
new file mode 100644
index 000000000..811231ae1
--- /dev/null
+++ b/docs/source/_rst/solver/causal_physics_informed_single_model_solver.rst
@@ -0,0 +1,11 @@
+Causal Physics Informed Single Model Solver
+=================================================
+
+.. currentmodule:: pina.solver.causal_physics_informed_single_model_solver
+
+.. automodule:: pina._src.solver.causal_physics_informed_single_model_solver
+
+.. autoclass:: pina._src.solver.causal_physics_informed_single_model_solver.CausalPhysicsInformedSingleModelSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/competitive_physics_informed_solver.rst b/docs/source/_rst/solver/competitive_physics_informed_solver.rst
new file mode 100644
index 000000000..9138dca85
--- /dev/null
+++ b/docs/source/_rst/solver/competitive_physics_informed_solver.rst
@@ -0,0 +1,11 @@
+Competitive Physics-Informed Solver
+=======================================
+
+.. currentmodule:: pina.solver.competitive_physics_informed_solver
+
+.. automodule:: pina._src.solver.competitive_physics_informed_solver
+
+.. autoclass:: pina._src.solver.competitive_physics_informed_solver.CompetitivePhysicsInformedSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/ensemble_solver.rst b/docs/source/_rst/solver/ensemble_solver.rst
new file mode 100644
index 000000000..1031422c0
--- /dev/null
+++ b/docs/source/_rst/solver/ensemble_solver.rst
@@ -0,0 +1,10 @@
+Ensemble Solver
+=================================
+.. currentmodule:: pina.solver.ensemble_solver
+
+.. automodule:: pina._src.solver.ensemble_solver
+
+.. autoclass:: pina._src.solver.ensemble_solver.EnsembleSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/ensemble_solver/ensemble_pinn.rst b/docs/source/_rst/solver/ensemble_solver/ensemble_pinn.rst
deleted file mode 100644
index 2e42dcf0d..000000000
--- a/docs/source/_rst/solver/ensemble_solver/ensemble_pinn.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-DeepEnsemblePINN
-==================
-.. currentmodule:: pina.solver.ensemble_solver.ensemble_pinn
-
-.. autoclass:: DeepEnsemblePINN
- :show-inheritance:
- :members:
-
diff --git a/docs/source/_rst/solver/ensemble_solver/ensemble_solver_interface.rst b/docs/source/_rst/solver/ensemble_solver/ensemble_solver_interface.rst
deleted file mode 100644
index 664bb8c8f..000000000
--- a/docs/source/_rst/solver/ensemble_solver/ensemble_solver_interface.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-DeepEnsembleSolverInterface
-=============================
-.. currentmodule:: pina.solver.ensemble_solver.ensemble_solver_interface
-
-.. autoclass:: DeepEnsembleSolverInterface
- :show-inheritance:
- :members:
-
diff --git a/docs/source/_rst/solver/ensemble_solver/ensemble_supervised.rst b/docs/source/_rst/solver/ensemble_solver/ensemble_supervised.rst
deleted file mode 100644
index 575b28594..000000000
--- a/docs/source/_rst/solver/ensemble_solver/ensemble_supervised.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-DeepEnsembleSupervisedSolver
-=============================
-.. currentmodule:: pina.solver.ensemble_solver.ensemble_supervised
-
-.. autoclass:: DeepEnsembleSupervisedSolver
- :show-inheritance:
- :members:
-
diff --git a/docs/source/_rst/solver/garom.rst b/docs/source/_rst/solver/garom.rst
deleted file mode 100644
index 0e5820f6f..000000000
--- a/docs/source/_rst/solver/garom.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-GAROM
-======
-.. currentmodule:: pina.solver.garom
-
-.. autoclass:: GAROM
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/gradient_physics_informed_single_model_solver.rst b/docs/source/_rst/solver/gradient_physics_informed_single_model_solver.rst
new file mode 100644
index 000000000..b602d2277
--- /dev/null
+++ b/docs/source/_rst/solver/gradient_physics_informed_single_model_solver.rst
@@ -0,0 +1,11 @@
+Gradient Physics Informed Single Model Solver
+=================================================
+
+.. currentmodule:: pina.solver.gradient_physics_informed_single_model_solver
+
+.. automodule:: pina._src.solver.gradient_physics_informed_single_model_solver
+
+.. autoclass:: pina._src.solver.gradient_physics_informed_single_model_solver.GradientPhysicsInformedSingleModelSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/mixin/autoregressive_mixin.rst b/docs/source/_rst/solver/mixin/autoregressive_mixin.rst
new file mode 100644
index 000000000..fee7df2ac
--- /dev/null
+++ b/docs/source/_rst/solver/mixin/autoregressive_mixin.rst
@@ -0,0 +1,11 @@
+Autoregressive Mixin
+=================================
+
+.. currentmodule:: pina.solver.mixin.autoregressive_mixin
+
+.. automodule:: pina._src.solver.mixin.autoregressive_mixin
+
+.. autoclass:: pina._src.solver.mixin.autoregressive_mixin.AutoregressiveMixin
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/mixin/condition_aggregator_mixin.rst b/docs/source/_rst/solver/mixin/condition_aggregator_mixin.rst
new file mode 100644
index 000000000..4868edff1
--- /dev/null
+++ b/docs/source/_rst/solver/mixin/condition_aggregator_mixin.rst
@@ -0,0 +1,11 @@
+Condition Aggregator Mixin
+=================================
+
+.. currentmodule:: pina.solver.mixin.condition_aggregator_mixin
+
+.. automodule:: pina._src.solver.mixin.condition_aggregator_mixin
+
+.. autoclass:: pina._src.solver.mixin.condition_aggregator_mixin.ConditionAggregatorMixin
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/mixin/ensemble_mixin.rst b/docs/source/_rst/solver/mixin/ensemble_mixin.rst
new file mode 100644
index 000000000..d3548e745
--- /dev/null
+++ b/docs/source/_rst/solver/mixin/ensemble_mixin.rst
@@ -0,0 +1,11 @@
+Ensemble Mixin
+=================================
+
+.. currentmodule:: pina.solver.mixin.ensemble_mixin
+
+.. automodule:: pina._src.solver.mixin.ensemble_mixin
+
+.. autoclass:: pina._src.solver.mixin.ensemble_mixin.EnsembleMixin
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/mixin/gradient_enhanced_mixin.rst b/docs/source/_rst/solver/mixin/gradient_enhanced_mixin.rst
new file mode 100644
index 000000000..f9ab310b6
--- /dev/null
+++ b/docs/source/_rst/solver/mixin/gradient_enhanced_mixin.rst
@@ -0,0 +1,11 @@
+Gradient-Enhanced Mixin
+=================================
+
+.. currentmodule:: pina.solver.mixin.gradient_enhanced_mixin
+
+.. automodule:: pina._src.solver.mixin.gradient_enhanced_mixin
+
+.. autoclass:: pina._src.solver.mixin.gradient_enhanced_mixin.GradientEnhancedMixin
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/mixin/manual_optimization_mixin.rst b/docs/source/_rst/solver/mixin/manual_optimization_mixin.rst
new file mode 100644
index 000000000..5974aa222
--- /dev/null
+++ b/docs/source/_rst/solver/mixin/manual_optimization_mixin.rst
@@ -0,0 +1,11 @@
+Manual Optimization Mixin
+=================================
+
+.. currentmodule:: pina.solver.mixin.manual_optimization_mixin
+
+.. automodule:: pina._src.solver.mixin.manual_optimization_mixin
+
+.. autoclass:: pina._src.solver.mixin.manual_optimization_mixin.ManualOptimizationMixin
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/mixin/multi_model_mixin.rst b/docs/source/_rst/solver/mixin/multi_model_mixin.rst
new file mode 100644
index 000000000..0302d1883
--- /dev/null
+++ b/docs/source/_rst/solver/mixin/multi_model_mixin.rst
@@ -0,0 +1,11 @@
+Multi-Model Mixin
+=================================
+
+.. currentmodule:: pina.solver.mixin.multi_model_mixin
+
+.. automodule:: pina._src.solver.mixin.multi_model_mixin
+
+.. autoclass:: pina._src.solver.mixin.multi_model_mixin.MultiModelMixin
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/mixin/physics_informed_mixin.rst b/docs/source/_rst/solver/mixin/physics_informed_mixin.rst
new file mode 100644
index 000000000..8503d9cbc
--- /dev/null
+++ b/docs/source/_rst/solver/mixin/physics_informed_mixin.rst
@@ -0,0 +1,11 @@
+Physics-Informed Mixin
+=================================
+
+.. currentmodule:: pina.solver.mixin.physics_informed_mixin
+
+.. automodule:: pina._src.solver.mixin.physics_informed_mixin
+
+.. autoclass:: pina._src.solver.mixin.physics_informed_mixin.PhysicsInformedMixin
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/mixin/residual_based_attention_mixin.rst b/docs/source/_rst/solver/mixin/residual_based_attention_mixin.rst
new file mode 100644
index 000000000..768c108a8
--- /dev/null
+++ b/docs/source/_rst/solver/mixin/residual_based_attention_mixin.rst
@@ -0,0 +1,11 @@
+Residual-Based Attention Mixin
+=================================
+
+.. currentmodule:: pina.solver.mixin.residual_based_attention_mixin
+
+.. automodule:: pina._src.solver.mixin.residual_based_attention_mixin
+
+.. autoclass:: pina._src.solver.mixin.residual_based_attention_mixin.ResidualBasedAttentionMixin
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/mixin/single_model_mixin.rst b/docs/source/_rst/solver/mixin/single_model_mixin.rst
new file mode 100644
index 000000000..c7c665793
--- /dev/null
+++ b/docs/source/_rst/solver/mixin/single_model_mixin.rst
@@ -0,0 +1,11 @@
+Single-Model Mixin
+=================================
+
+.. currentmodule:: pina.solver.mixin.single_model_mixin
+
+.. automodule:: pina._src.solver.mixin.single_model_mixin
+
+.. autoclass:: pina._src.solver.mixin.single_model_mixin.SingleModelMixin
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/multi_model_solver.rst b/docs/source/_rst/solver/multi_model_solver.rst
new file mode 100644
index 000000000..37e1cf4df
--- /dev/null
+++ b/docs/source/_rst/solver/multi_model_solver.rst
@@ -0,0 +1,10 @@
+Multi Model Solver
+=================================
+.. currentmodule:: pina.solver.multi_model_solver
+
+.. automodule:: pina._src.solver.multi_model_solver
+
+.. autoclass:: pina._src.solver.multi_model_solver.MultiModelSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/multi_solver_interface.rst b/docs/source/_rst/solver/multi_solver_interface.rst
deleted file mode 100644
index 7f68c83a4..000000000
--- a/docs/source/_rst/solver/multi_solver_interface.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-MultiSolverInterface
-======================
-.. currentmodule:: pina.solver.solver
-
-.. autoclass:: MultiSolverInterface
- :show-inheritance:
- :members:
-
diff --git a/docs/source/_rst/solver/physics_informed_ensemble_solver.rst b/docs/source/_rst/solver/physics_informed_ensemble_solver.rst
new file mode 100644
index 000000000..726da79f8
--- /dev/null
+++ b/docs/source/_rst/solver/physics_informed_ensemble_solver.rst
@@ -0,0 +1,10 @@
+Physics Informed Ensemble Solver
+=================================
+.. currentmodule:: pina.solver.physics_informed_ensemble_solver
+
+.. automodule:: pina._src.solver.physics_informed_ensemble_solver
+
+.. autoclass:: pina._src.solver.physics_informed_ensemble_solver.PhysicsInformedEnsembleSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/physics_informed_single_model_solver.rst b/docs/source/_rst/solver/physics_informed_single_model_solver.rst
new file mode 100644
index 000000000..38f5952d6
--- /dev/null
+++ b/docs/source/_rst/solver/physics_informed_single_model_solver.rst
@@ -0,0 +1,10 @@
+Physics Informed Single Model Solver
+=======================================
+.. currentmodule:: pina.solver.physics_informed_single_model_solver
+
+.. automodule:: pina._src.solver.physics_informed_single_model_solver
+
+.. autoclass:: pina._src.solver.physics_informed_single_model_solver.PhysicsInformedSingleModelSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/physics_informed_solver/causal_pinn.rst b/docs/source/_rst/solver/physics_informed_solver/causal_pinn.rst
deleted file mode 100644
index 6fab9ef0e..000000000
--- a/docs/source/_rst/solver/physics_informed_solver/causal_pinn.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-CausalPINN
-==============
-.. currentmodule:: pina.solver.physics_informed_solver.causal_pinn
-
-.. autoclass:: CausalPINN
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/physics_informed_solver/competitive_pinn.rst b/docs/source/_rst/solver/physics_informed_solver/competitive_pinn.rst
deleted file mode 100644
index 372cb0f3d..000000000
--- a/docs/source/_rst/solver/physics_informed_solver/competitive_pinn.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-CompetitivePINN
-=================
-.. currentmodule:: pina.solver.physics_informed_solver.competitive_pinn
-
-.. autoclass:: CompetitivePINN
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/physics_informed_solver/gradient_pinn.rst b/docs/source/_rst/solver/physics_informed_solver/gradient_pinn.rst
deleted file mode 100644
index 66a490013..000000000
--- a/docs/source/_rst/solver/physics_informed_solver/gradient_pinn.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-GradientPINN
-==============
-.. currentmodule:: pina.solver.physics_informed_solver.gradient_pinn
-
-.. autoclass:: GradientPINN
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/physics_informed_solver/pinn.rst b/docs/source/_rst/solver/physics_informed_solver/pinn.rst
deleted file mode 100644
index fdc31253b..000000000
--- a/docs/source/_rst/solver/physics_informed_solver/pinn.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-PINN
-======
-.. currentmodule:: pina.solver.physics_informed_solver.pinn
-
-.. autoclass:: PINN
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/physics_informed_solver/pinn_interface.rst b/docs/source/_rst/solver/physics_informed_solver/pinn_interface.rst
deleted file mode 100644
index 2242cf8b4..000000000
--- a/docs/source/_rst/solver/physics_informed_solver/pinn_interface.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-PINNInterface
-=================
-.. currentmodule:: pina.solver.physics_informed_solver.pinn_interface
-
-.. autoclass:: PINNInterface
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/physics_informed_solver/rba_pinn.rst b/docs/source/_rst/solver/physics_informed_solver/rba_pinn.rst
deleted file mode 100644
index cf94b6df0..000000000
--- a/docs/source/_rst/solver/physics_informed_solver/rba_pinn.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-RBAPINN
-========
-.. currentmodule:: pina.solver.physics_informed_solver.rba_pinn
-
-.. autoclass:: RBAPINN
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/physics_informed_solver/self_adaptive_pinn.rst b/docs/source/_rst/solver/physics_informed_solver/self_adaptive_pinn.rst
deleted file mode 100644
index 2290059bd..000000000
--- a/docs/source/_rst/solver/physics_informed_solver/self_adaptive_pinn.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-SelfAdaptivePINN
-==================
-.. currentmodule:: pina.solver.physics_informed_solver.self_adaptive_pinn
-
-.. autoclass:: SelfAdaptivePINN
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/rba_physics_informed_single_model_solver.rst b/docs/source/_rst/solver/rba_physics_informed_single_model_solver.rst
new file mode 100644
index 000000000..7765d2d95
--- /dev/null
+++ b/docs/source/_rst/solver/rba_physics_informed_single_model_solver.rst
@@ -0,0 +1,11 @@
+RBA Physics-Informed Single-Model Solver
+=================================================
+
+.. currentmodule:: pina.solver.rba_physics_informed_single_model_solver
+
+.. automodule:: pina._src.solver.rba_physics_informed_single_model_solver
+
+.. autoclass:: pina._src.solver.rba_physics_informed_single_model_solver.RBAPhysicsInformedSingleModelSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/self_adaptive_physics_informed_solver.rst b/docs/source/_rst/solver/self_adaptive_physics_informed_solver.rst
new file mode 100644
index 000000000..901520e1b
--- /dev/null
+++ b/docs/source/_rst/solver/self_adaptive_physics_informed_solver.rst
@@ -0,0 +1,11 @@
+Self-Adaptive Physics-Informed Solver
+=======================================
+
+.. currentmodule:: pina.solver.self_adaptive_physics_informed_solver
+
+.. automodule:: pina._src.solver.self_adaptive_physics_informed_solver
+
+.. autoclass:: pina._src.solver.self_adaptive_physics_informed_solver.SelfAdaptivePhysicsInformedSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/single_model_solver.rst b/docs/source/_rst/solver/single_model_solver.rst
new file mode 100644
index 000000000..7bb4857d5
--- /dev/null
+++ b/docs/source/_rst/solver/single_model_solver.rst
@@ -0,0 +1,10 @@
+Single Model Solver
+=================================
+.. currentmodule:: pina.solver.single_model_solver
+
+.. automodule:: pina._src.solver.single_model_solver
+
+.. autoclass:: pina._src.solver.single_model_solver.SingleModelSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/single_solver_interface.rst b/docs/source/_rst/solver/single_solver_interface.rst
deleted file mode 100644
index 5b85f11b5..000000000
--- a/docs/source/_rst/solver/single_solver_interface.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-SingleSolverInterface
-======================
-.. currentmodule:: pina.solver.solver
-
-.. autoclass:: SingleSolverInterface
- :show-inheritance:
- :members:
-
diff --git a/docs/source/_rst/solver/solver_interface.rst b/docs/source/_rst/solver/solver_interface.rst
index 9bb11783e..b9f4b9e66 100644
--- a/docs/source/_rst/solver/solver_interface.rst
+++ b/docs/source/_rst/solver/solver_interface.rst
@@ -1,8 +1,10 @@
-SolverInterface
-=================
-.. currentmodule:: pina.solver.solver
+Solver Interface
+=================================
+.. currentmodule:: pina.solver.solver_interface
-.. autoclass:: SolverInterface
- :show-inheritance:
+.. automodule:: pina._src.solver.solver_interface
+
+.. autoclass:: pina._src.solver.solver_interface.SolverInterface
:members:
-
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/supervised_ensemble_solver.rst b/docs/source/_rst/solver/supervised_ensemble_solver.rst
new file mode 100644
index 000000000..23f276640
--- /dev/null
+++ b/docs/source/_rst/solver/supervised_ensemble_solver.rst
@@ -0,0 +1,10 @@
+Supervised Ensemble Solver
+=================================
+.. currentmodule:: pina.solver.supervised_ensemble_solver
+
+.. automodule:: pina._src.solver.supervised_ensemble_solver
+
+.. autoclass:: pina._src.solver.supervised_ensemble_solver.SupervisedEnsembleSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/supervised_single_model_solver.rst b/docs/source/_rst/solver/supervised_single_model_solver.rst
new file mode 100644
index 000000000..13c3b2fc0
--- /dev/null
+++ b/docs/source/_rst/solver/supervised_single_model_solver.rst
@@ -0,0 +1,11 @@
+Supervised Single Model Solver
+=================================
+
+.. currentmodule:: pina.solver.supervised_single_model_solver
+
+.. automodule:: pina._src.solver.supervised_single_model_solver
+
+.. autoclass:: pina._src.solver.supervised_single_model_solver.SupervisedSingleModelSolver
+ :members:
+ :show-inheritance:
+ :noindex:
diff --git a/docs/source/_rst/solver/supervised_solver/reduced_order_model.rst b/docs/source/_rst/solver/supervised_solver/reduced_order_model.rst
deleted file mode 100644
index 878014c29..000000000
--- a/docs/source/_rst/solver/supervised_solver/reduced_order_model.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-ReducedOrderModelSolver
-==========================
-.. currentmodule:: pina.solver.supervised_solver.reduced_order_model
-
-.. autoclass:: ReducedOrderModelSolver
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/supervised_solver/supervised.rst b/docs/source/_rst/solver/supervised_solver/supervised.rst
deleted file mode 100644
index 60ffdf828..000000000
--- a/docs/source/_rst/solver/supervised_solver/supervised.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-SupervisedSolver
-===================
-.. currentmodule:: pina.solver.supervised_solver.supervised
-
-.. autoclass:: SupervisedSolver
- :members:
- :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/supervised_solver/supervised_solver_interface.rst b/docs/source/_rst/solver/supervised_solver/supervised_solver_interface.rst
deleted file mode 100644
index 4903a18dd..000000000
--- a/docs/source/_rst/solver/supervised_solver/supervised_solver_interface.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-SupervisedSolverInterface
-==========================
-.. currentmodule:: pina.solver.supervised_solver.supervised_solver_interface
-
-.. autoclass:: SupervisedSolverInterface
- :show-inheritance:
- :members:
-
diff --git a/docs/source/_rst/trainer.rst b/docs/source/_rst/trainer.rst
index 2582b6da9..8e5a99a38 100644
--- a/docs/source/_rst/trainer.rst
+++ b/docs/source/_rst/trainer.rst
@@ -3,6 +3,6 @@ Trainer
.. automodule:: pina.trainer
-.. autoclass:: Trainer
+.. autoclass:: pina._src.core.trainer.Trainer
:members:
:show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/weighting/base_weighting.rst b/docs/source/_rst/weighting/base_weighting.rst
new file mode 100644
index 000000000..c8544697d
--- /dev/null
+++ b/docs/source/_rst/weighting/base_weighting.rst
@@ -0,0 +1,9 @@
+BaseWeighting
+===================
+.. currentmodule:: pina.weighting.base_weighting
+
+.. automodule:: pina._src.weighting.base_weighting
+
+.. autoclass:: pina._src.weighting.base_weighting.BaseWeighting
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/weighting/linear_weighting.rst b/docs/source/_rst/weighting/linear_weighting.rst
new file mode 100644
index 000000000..1941bbe80
--- /dev/null
+++ b/docs/source/_rst/weighting/linear_weighting.rst
@@ -0,0 +1,11 @@
+LinearWeighting
+=============================
+
+.. currentmodule:: pina.weighting
+
+.. automodule:: pina._src.weighting.linear_weighting
+ :no-members:
+
+.. autoclass:: pina._src.weighting.linear_weighting.LinearWeighting
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/weighting/no_weighting.rst b/docs/source/_rst/weighting/no_weighting.rst
new file mode 100644
index 000000000..f6794eb5c
--- /dev/null
+++ b/docs/source/_rst/weighting/no_weighting.rst
@@ -0,0 +1,9 @@
+No Weighting
+===================
+.. currentmodule:: pina.weighting.no_weighting
+
+.. automodule:: pina._src.weighting.no_weighting
+
+.. autoclass:: pina._src.weighting.no_weighting._NoWeighting
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/weighting/ntk_weighting.rst b/docs/source/_rst/weighting/ntk_weighting.rst
new file mode 100644
index 000000000..acee58fa2
--- /dev/null
+++ b/docs/source/_rst/weighting/ntk_weighting.rst
@@ -0,0 +1,9 @@
+NeuralTangentKernelWeighting
+=============================
+.. currentmodule:: pina.weighting.ntk_weighting
+
+.. automodule:: pina._src.weighting.ntk_weighting
+
+.. autoclass:: pina._src.weighting.ntk_weighting.NeuralTangentKernelWeighting
+ :members:
+ :show-inheritance:
diff --git a/docs/source/_rst/weighting/scalar_weighting.rst b/docs/source/_rst/weighting/scalar_weighting.rst
new file mode 100644
index 000000000..712425086
--- /dev/null
+++ b/docs/source/_rst/weighting/scalar_weighting.rst
@@ -0,0 +1,9 @@
+ScalarWeighting
+===================
+.. currentmodule:: pina.weighting.scalar_weighting
+
+.. automodule:: pina._src.weighting.scalar_weighting
+
+.. autoclass:: pina._src.weighting.scalar_weighting.ScalarWeighting
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/weighting/self_adaptive_weighting.rst b/docs/source/_rst/weighting/self_adaptive_weighting.rst
new file mode 100644
index 000000000..32ed13aba
--- /dev/null
+++ b/docs/source/_rst/weighting/self_adaptive_weighting.rst
@@ -0,0 +1,9 @@
+SelfAdaptiveWeighting
+=============================
+.. currentmodule:: pina.weighting.self_adaptive_weighting
+
+.. automodule:: pina._src.weighting.self_adaptive_weighting
+
+.. autoclass:: pina._src.weighting.self_adaptive_weighting.SelfAdaptiveWeighting
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/weighting/weighting_interface.rst b/docs/source/_rst/weighting/weighting_interface.rst
new file mode 100644
index 000000000..19cf34b42
--- /dev/null
+++ b/docs/source/_rst/weighting/weighting_interface.rst
@@ -0,0 +1,9 @@
+WeightingInterface
+===================
+.. currentmodule:: pina.weighting.weighting_interface
+
+.. automodule:: pina._src.weighting.weighting_interface
+
+.. autoclass:: pina._src.weighting.weighting_interface.WeightingInterface
+ :members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/pina/__init__.py b/pina/__init__.py
index 2cbe7f3bb..cafab2d31 100644
--- a/pina/__init__.py
+++ b/pina/__init__.py
@@ -1,18 +1,37 @@
-"""Module for the Pina library."""
+"""
+A specialized framework for Scientific Machine Learning (SciML), providing
+tools for Physics-Informed Neural Networks (PINNs), Neural Operators,
+and data-driven physical modeling.
+"""
__all__ = [
- "Trainer",
"LabelTensor",
+ "Trainer",
"Condition",
- "PinaDataModule",
+ "DataModule",
"Graph",
- "SolverInterface",
- "MultiSolverInterface",
]
-from .label_tensor import LabelTensor
-from .graph import Graph
-from .solver import SolverInterface, MultiSolverInterface
-from .trainer import Trainer
-from .condition.condition import Condition
-from .data import PinaDataModule
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.graph import Graph
+from pina._src.core.trainer import Trainer
+from pina._src.condition.condition import Condition
+from pina._src.data.data_module import DataModule
+
+# Back-compatibility with version 0.2, to be removed soon
+import warnings
+
+_DEPRECATED_IMPORTS = {"PinaDataModule": "DataModule"}
+
+
+def __getattr__(name):
+ if name in _DEPRECATED_IMPORTS:
+
+ warnings.warn(
+ f"Importing '{name}' from 'pina' is deprecated; use "
+ f"pina.{_DEPRECATED_IMPORTS[name]} instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return globals()[_DEPRECATED_IMPORTS[name]]
diff --git a/pina/_src/__init__.py b/pina/_src/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/adaptive_function/__init__.py b/pina/_src/adaptive_function/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/adaptive_function/adaptive_celu.py b/pina/_src/adaptive_function/adaptive_celu.py
new file mode 100644
index 000000000..bb460933c
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_celu.py
@@ -0,0 +1,77 @@
+"""Module for the Adaptive CELU activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveCELU(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`~torch.nn.CELU` activation.
+
+ This module extends the standard CELU by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function :math:`\text{CELU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
+ the corresponding adaptive activation
+ :math:`\text{CELU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n` is
+ defined as:
+
+ .. math::
+ \text{CELU}_{\text{adaptive}}({x})=\alpha\,\text{CELU}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ The CELU function is defined elementwise as:
+
+ .. math::
+ \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveCELU` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.CELU()
diff --git a/pina/_src/adaptive_function/adaptive_elu.py b/pina/_src/adaptive_function/adaptive_elu.py
new file mode 100644
index 000000000..12b40fa46
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_elu.py
@@ -0,0 +1,80 @@
+"""Module for the Adaptive ELU activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveELU(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`~torch.nn.ELU` activation.
+
+ This module extends the standard ELU by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function :math:`\text{ELU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
+ the corresponding adaptive activation
+ :math:`\text{ELU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n` is
+ defined as:
+
+ .. math::
+ \text{ELU}_{\text{adaptive}}({x}) = \alpha\,\text{ELU}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ The ELU function is defined elementwise as:
+
+ .. math::
+ \text{ELU}(x) = \begin{cases}
+ x, & \text{ if }x > 0\\
+ \exp(x) - 1, & \text{ if }x \leq 0
+ \end{cases}
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveELU` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.ELU()
diff --git a/pina/_src/adaptive_function/adaptive_exp.py b/pina/_src/adaptive_function/adaptive_exp.py
new file mode 100644
index 000000000..c6484f8c9
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_exp.py
@@ -0,0 +1,73 @@
+"""Module for the Adaptive Exp activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveExp(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :obj:`~torch.exp` activation.
+
+ This module extends the standard exponential function by introducing
+ learnable scaling and shifting parameters applied to both the input and the
+ output.
+
+ Given the function :math:`\text{exp}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
+ the corresponding adaptive activation
+ :math:`\text{exp}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
+ is defined as:
+
+ .. math::
+ \text{exp}_{\text{adaptive}}({x}) = \alpha\,\text{exp}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveExp` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.exp
diff --git a/pina/_src/adaptive_function/adaptive_function_interface.py b/pina/_src/adaptive_function/adaptive_function_interface.py
new file mode 100644
index 000000000..d53694bcd
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_function_interface.py
@@ -0,0 +1,50 @@
+"""Module for the Adaptive Function Interface."""
+
+from abc import ABCMeta, abstractmethod
+
+
+class AdaptiveFunctionInterface(metaclass=ABCMeta):
+ """
+ Abstract interface for all adaptive functions.
+ """
+
+ @abstractmethod
+ def forward(self, x):
+ """
+ Compute the transformation of the adaptive function on the input.
+
+ :param x: The input tensor to evaluate the adaptive function.
+ :type x: torch.Tensor | LabelTensor
+ :return: The output of the adaptive function.
+ :rtype: torch.Tensor | LabelTensor
+ """
+
+ @property
+ @abstractmethod
+ def alpha(self):
+ """
+ The output scaling parameter of the adaptive function.
+
+ :return: The alpha parameter.
+ :rtype: torch.nn.Parameter | torch.Tensor
+ """
+
+ @property
+ @abstractmethod
+ def beta(self):
+ """
+ The input scaling parameter of the adaptive function.
+
+ :return: The beta parameter.
+ :rtype: torch.nn.Parameter | torch.Tensor
+ """
+
+ @property
+ @abstractmethod
+ def gamma(self):
+ """
+ The input shifting parameter of the adaptive function.
+
+ :return: The gamma parameter.
+ :rtype: torch.nn.Parameter | torch.Tensor
+ """
diff --git a/pina/_src/adaptive_function/adaptive_gelu.py b/pina/_src/adaptive_function/adaptive_gelu.py
new file mode 100644
index 000000000..148d43d52
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_gelu.py
@@ -0,0 +1,78 @@
+"""Module for the Adaptive GELU activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveGELU(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`~torch.nn.GELU` activation.
+
+ This module extends the standard GELU by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function :math:`\text{GELU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
+ the corresponding adaptive activation
+ :math:`\text{GELU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n` is
+ defined as:
+
+ .. math::
+ \text{GELU}_{\text{adaptive}}({x})=\alpha\,\text{GELU}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ The GELU function is defined elementwise as:
+
+ .. math::
+ \text{GELU}(x)=0.5*x*(1+\text{Tanh}(\sqrt{2 / \pi}*(x+0.044715*x^3)))
+
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveGELU` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.GELU()
diff --git a/pina/_src/adaptive_function/adaptive_mish.py b/pina/_src/adaptive_function/adaptive_mish.py
new file mode 100644
index 000000000..1c7278a1e
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_mish.py
@@ -0,0 +1,77 @@
+"""Module for the Adaptive Mish activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveMish(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`~torch.nn.Mish` activation.
+
+ This module extends the standard Mish by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function :math:`\text{Mish}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
+ the corresponding adaptive activation
+ :math:`\text{Mish}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n` is
+ defined as:
+
+ .. math::
+ \text{Mish}_{\text{adaptive}}({x})=\alpha\,\text{Mish}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ The Mish function is defined elementwise as:
+
+ .. math::
+ \text{Mish}(x) = x * \text{Tanh}(x)
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveMish` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.Mish()
diff --git a/pina/_src/adaptive_function/adaptive_relu.py b/pina/_src/adaptive_function/adaptive_relu.py
new file mode 100644
index 000000000..bd8ec0879
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_relu.py
@@ -0,0 +1,78 @@
+"""Module for the Adaptive ReLU activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveReLU(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`torch.nn.ReLU` activation.
+
+ This module extends the standard ReLU by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function :math:`\text{ReLU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
+ the corresponding adaptive activation
+ :math:`\text{ReLU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n` is
+ defined as:
+
+ .. math::
+ \text{ReLU}_{\text{adaptive}}(x) =
+ \alpha \, \text{ReLU}(\beta x + \gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ The ReLU function is defined elementwise as:
+
+ .. math::
+ \text{ReLU}(x) = \max(0, x).
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveReLU` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.ReLU()
diff --git a/pina/_src/adaptive_function/adaptive_sigmoid.py b/pina/_src/adaptive_function/adaptive_sigmoid.py
new file mode 100644
index 000000000..c88eafab2
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_sigmoid.py
@@ -0,0 +1,79 @@
+"""Module for the Adaptive Sigmoid activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveSigmoid(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`~torch.nn.Sigmoid` activation.
+
+ This module extends the standard Sigmoid by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function
+ :math:`\text{Sigmoid}:\mathbb{R}^n\rightarrow\mathbb{R}^n`, the
+ corresponding adaptive activation
+ :math:`\text{Sigmoid}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
+ is defined as:
+
+ .. math::
+ \text{Sigmoid}_{\text{adaptive}}({x})=
+ \alpha\,\text{Sigmoid}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ The Sigmoid function is defined elementwise as:
+
+ .. math::
+ \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveSigmoid` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.Sigmoid()
diff --git a/pina/_src/adaptive_function/adaptive_silu.py b/pina/_src/adaptive_function/adaptive_silu.py
new file mode 100644
index 000000000..d35b867a6
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_silu.py
@@ -0,0 +1,79 @@
+"""Module for the Adaptive SiLU activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveSiLU(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`~torch.nn.SiLU` activation.
+
+ This module extends the standard SiLU by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function :math:`\text{SiLU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
+ the corresponding adaptive activation
+ :math:`\text{SiLU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n` is
+ defined as:
+
+ .. math::
+ \text{SiLU}_{\text{adaptive}}({x})=\alpha\,\text{SiLU}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ The SiLU function is defined elementwise as:
+
+ .. math::
+ \text{SiLU}(x) = x * \sigma(x),
+
+ where :math:`\sigma(x)` is the logistic sigmoid function.
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveSiLU` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.SiLU()
diff --git a/pina/_src/adaptive_function/adaptive_siren.py b/pina/_src/adaptive_function/adaptive_siren.py
new file mode 100644
index 000000000..dfb42b4b9
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_siren.py
@@ -0,0 +1,72 @@
+"""Module for the Adaptive SIREN activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveSIREN(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :obj:`~torch.sin` activation.
+
+ This module extends the standard SIREN by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function :math:`\text{sin}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
+ the corresponding adaptive activation
+ :math:`\text{sin}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
+ is defined as:
+
+ .. math::
+ \text{sin}_{\text{adaptive}}({x}) = \alpha\,\text{sin}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveSIREN` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.sin
diff --git a/pina/_src/adaptive_function/adaptive_softmax.py b/pina/_src/adaptive_function/adaptive_softmax.py
new file mode 100644
index 000000000..7f2ad156f
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_softmax.py
@@ -0,0 +1,79 @@
+"""Module for the Adaptive Softmax activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveSoftmax(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`~torch.nn.Softmax` activation.
+
+ This module extends the standard Softmax by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function
+ :math:`\text{Softmax}:\mathbb{R}^n\rightarrow\mathbb{R}^n`, the
+ corresponding adaptive activation
+ :math:`\text{Softmax}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
+ is defined as:
+
+ .. math::
+ \text{Softmax}_{\text{adaptive}}({x})=\alpha\,
+ \text{Softmax}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ The Softmax function is defined elementwise as:
+
+ .. math::
+ \text{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveSoftmax` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.Softmax(dim=-1)
diff --git a/pina/_src/adaptive_function/adaptive_softmin.py b/pina/_src/adaptive_function/adaptive_softmin.py
new file mode 100644
index 000000000..b07e27bbf
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_softmin.py
@@ -0,0 +1,79 @@
+"""Module for the Adaptive Softmin activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveSoftmin(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`~torch.nn.Softmin` activation.
+
+ This module extends the standard Softmin by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function
+ :math:`\text{Softmin}:\mathbb{R}^n\rightarrow\mathbb{R}^n`, the
+ corresponding adaptive activation
+ :math:`\text{Softmin}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
+ is defined as:
+
+ .. math::
+ \text{Softmin}_{\text{adaptive}}({x})=\alpha\,
+ \text{Softmin}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ The Softmin function is defined elementwise as:
+
+ .. math::
+ \text{Softmin}(x_i) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveSoftmin` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.Softmin(dim=-1)
diff --git a/pina/_src/adaptive_function/adaptive_tanh.py b/pina/_src/adaptive_function/adaptive_tanh.py
new file mode 100644
index 000000000..513f4b3f0
--- /dev/null
+++ b/pina/_src/adaptive_function/adaptive_tanh.py
@@ -0,0 +1,72 @@
+"""Module for the Adaptive Tanh activation function."""
+
+import torch
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
+)
+
+
+class AdaptiveTanh(BaseAdaptiveFunction):
+ r"""
+ Adaptive, trainable variant of the :class:`~torch.nn.Tanh` activation.
+
+ This module extends the standard Tanh by introducing learnable scaling
+ and shifting parameters applied to both the input and the output.
+
+ Given the function :math:`\text{Tanh}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
+ the corresponding adaptive activation
+ :math:`\text{Tanh}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n` is
+ defined as:
+
+ .. math::
+ \text{Tanh}_{\text{adaptive}}({x})=\alpha\,\text{Tanh}(\beta{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are trainable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`AdaptiveTanh` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__(alpha, beta, gamma, fixed)
+ self._func = torch.nn.Tanh()
diff --git a/pina/_src/adaptive_function/base_adaptive_function.py b/pina/_src/adaptive_function/base_adaptive_function.py
new file mode 100644
index 000000000..c391d308a
--- /dev/null
+++ b/pina/_src/adaptive_function/base_adaptive_function.py
@@ -0,0 +1,178 @@
+"""Module for the Adaptive Function base class."""
+
+import torch
+from pina._src.core.utils import check_consistency
+from pina._src.adaptive_function.adaptive_function_interface import (
+ AdaptiveFunctionInterface,
+)
+
+
+class BaseAdaptiveFunction(torch.nn.Module, AdaptiveFunctionInterface):
+ r"""
+ Base class for all adaptive functions, implementing common functionality.
+
+ This class extends a standard :class:`torch.nn.Module` activation function
+ into a trainable adaptive form. It implements the common mechanism used to
+ scale and shift both the input and the output of a given activation
+ function.
+
+ Given a function :math:`f:\mathbb{R}^n\rightarrow\mathbb{R}^m`, the adaptive
+ function :math:`f_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^m`
+ is defined as:
+
+ .. math::
+ f_{\text{adaptive}}(\mathbf{x}) = \alpha\,f(\beta\mathbf{x}+\gamma),
+
+ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are learnable
+ parameters controlling output scaling, input scaling, and input shifting,
+ respectively.
+
+ All specific adaptive functions should inherit from this class and implement
+ the abstract methods declared in the interface.
+
+ This class is not meant to be instantiated directly.
+
+ .. seealso::
+
+ **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
+ *A continuum among logarithmic, linear, and exponential functions,
+ and its potential to improve generalization in neural networks.*
+ 7th international joint conference on knowledge discovery, knowledge
+ engineering and knowledge management (IC3K), Vol. 1.
+ DOI: `arXiv preprint arXiv:1602.01321.
+ `_.
+
+ **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
+ *Adaptive activation functions accelerate convergence in deep and
+ physics-informed neural networks*.
+ Journal of Computational Physics, 404.
+ DOI: `JCP 10.1016 `_.
+ """
+
+ def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
+ """
+ Initialization of the :class:`BaseAdaptiveFunction` class.
+
+ :param alpha: The output scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type alpha: int | float
+ :param beta: The input scaling parameter of the adaptive function.
+ If ``None``, it is initialized to ``1``. Default is ``None``.
+ :type beta: int | float
+ :param gamma: The input shifting parameter of the adaptive function.
+ If ``None``, it is initialized to ``0``. Default is ``None``.
+ :type gamma: int | float
+ :param fixed: The names of parameters to keep fixed during training.
+ These parameters will not be optimized and will have
+ ``requires_grad=False``. Available options are ``"alpha"``,
+ ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
+ trainable. Default is ``None``.
+ :type fixed: str | list[str]
+ :raises ValueError: If alpha, when provided, is not a number.
+ :raises ValueError: If beta, when provided, is not a number.
+ :raises ValueError: If gamma, when provided, is not a number.
+ :raises ValueError: If fixed, when provided, is neither a string nor a
+ list of strings.
+ :raises ValueError: If fixed contains invalid parameter names.
+ """
+ super().__init__()
+
+ # Set default values for alpha, beta, gamma if they are None
+ alpha = 1.0 if alpha is None else alpha
+ beta = 1.0 if beta is None else beta
+ gamma = 0.0 if gamma is None else gamma
+
+ # Check consistency
+ check_consistency(alpha, (int, float))
+ check_consistency(beta, (int, float))
+ check_consistency(gamma, (int, float))
+
+ # Process fixed parameters
+ if fixed is not None:
+ check_consistency(fixed, str)
+ fixed = {fixed} if isinstance(fixed, str) else set(fixed)
+ else:
+ fixed = set()
+
+ # Validate fixed parameter names
+ invalid_names = fixed - {"alpha", "beta", "gamma"}
+ if invalid_names:
+ raise ValueError(
+ f"Invalid fixed parameter name(s): {sorted(invalid_names)}. "
+ "Available options are 'alpha', 'beta', and 'gamma'."
+ )
+
+ # Register either a trainable parameter or a fixed buffer
+ def _register_adaptive_param(name, value):
+ """
+ Helper function to register an adaptive parameter as either a
+ trainable parameter or a fixed buffer, depending on whether it is
+ specified in the ``fixed`` argument.
+ """
+ # Convert value to tensor
+ tensor = torch.tensor(value, dtype=torch.float32)
+
+ # Register as buffer if fixed, otherwise as parameter
+ if name in fixed:
+ self.register_buffer(f"_{name}", tensor)
+ else:
+ setattr(self, f"_{name}", torch.nn.Parameter(tensor))
+
+ # Register parameters
+ _register_adaptive_param("alpha", alpha)
+ _register_adaptive_param("beta", beta)
+ _register_adaptive_param("gamma", gamma)
+
+ # Initialize the adaptive function to None, to be set by subclasses
+ self._func = None
+
+ def forward(self, x):
+ """
+ Compute the transformation of the adaptive function on the input.
+
+ :param x: The input tensor to evaluate the adaptive function.
+ :type x: torch.Tensor | LabelTensor
+ :raises RuntimeError: If the adaptive function has not been set.
+ :raises RuntimeError: If the adaptive function is not callable.
+ :return: The output of the adaptive function.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Raise an error if the adaptive function has not been set
+ if self._func is None:
+ raise RuntimeError("The adaptive function has not been set.")
+
+ # Raise an error if the adaptive function is not callable
+ if not callable(self._func):
+ raise RuntimeError("The adaptive function is not callable.")
+
+ return self.alpha * (self._func(self.beta * x + self.gamma))
+
+ @property
+ def alpha(self):
+ """
+ The output scaling parameter of the adaptive function.
+
+ :return: The alpha parameter.
+ :rtype: torch.nn.Parameter | torch.Tensor
+ """
+ return self._alpha
+
+ @property
+ def beta(self):
+ """
+ The input scaling parameter of the adaptive function.
+
+ :return: The beta parameter.
+ :rtype: torch.nn.Parameter | torch.Tensor
+ """
+ return self._beta
+
+ @property
+ def gamma(self):
+ """
+ The input shifting parameter of the adaptive function.
+
+ :return: The gamma parameter.
+ :rtype: torch.nn.Parameter | torch.Tensor
+ """
+ return self._gamma
diff --git a/pina/_src/callback/__init__.py b/pina/_src/callback/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/callback/optim/__init__.py b/pina/_src/callback/optim/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/callback/optim/switch_optimizer.py b/pina/_src/callback/optim/switch_optimizer.py
similarity index 57%
rename from pina/callback/optim/switch_optimizer.py
rename to pina/_src/callback/optim/switch_optimizer.py
index 3072b7c2e..36561fa28 100644
--- a/pina/callback/optim/switch_optimizer.py
+++ b/pina/_src/callback/optim/switch_optimizer.py
@@ -1,27 +1,36 @@
"""Module for the SwitchOptimizer callback."""
from lightning.pytorch.callbacks import Callback
-from ...optim import TorchOptimizer
-from ...utils import check_consistency
+from pina._src.optim.optimizer_interface import OptimizerInterface
+from pina._src.core.utils import check_consistency, check_positive_integer
class SwitchOptimizer(Callback):
"""
- PINA Implementation of a Lightning Callback to switch optimizer during
- training.
+ Lightning callback for dynamically replacing optimizers during training.
+
+ This callback enables switching to one or more new optimizers at a specified
+ epoch without restarting the training loop. It is particularly useful for
+ staged optimization strategies (e.g., coarse-to-fine training or optimizer
+ warm-up phases), where different optimizers are applied sequentially.
+
+ At the target epoch, the provided optimizers are hooked to the model
+ parameters and replace the current optimizers in both the PINA solver and
+ the Lightning trainer strategy.
"""
def __init__(self, new_optimizers, epoch_switch):
"""
- This callback allows switching between different optimizers during
- training, enabling the exploration of multiple optimization strategies
- without interrupting the training process.
+ Initialization of the :class:`SwitchOptimizer` class.
:param new_optimizers: The model optimizers to switch to. Can be a
single :class:`torch.optim.Optimizer` instance or a list of them
for multiple model solver.
- :type new_optimizers: pina.optim.TorchOptimizer | list
+ :type new_optimizers: pina.optim.OptimizerInterface | list
:param int epoch_switch: The epoch at which the optimizer switch occurs.
+ :raises AssertionError: If ``epoch_switch`` is not a positive integer.
+ :raises ValueError: If any of the provided optimizers are not instances
+ of :class:`pina.optim.OptimizerInterface`.
Example:
>>> optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
@@ -31,19 +40,14 @@ def __init__(self, new_optimizers, epoch_switch):
"""
super().__init__()
- # Check if epoch_switch is greater than 1
- if epoch_switch < 1:
- raise ValueError("epoch_switch must be greater than one.")
+ # Check consistency
+ check_positive_integer(epoch_switch, strict=True)
+ check_consistency(new_optimizers, OptimizerInterface)
# If new_optimizers is not a list, convert it to a list
if not isinstance(new_optimizers, list):
new_optimizers = [new_optimizers]
- # Check consistency
- check_consistency(epoch_switch, int)
- for optimizer in new_optimizers:
- check_consistency(optimizer, TorchOptimizer)
-
# Store the new optimizers and epoch switch
self._new_optimizers = new_optimizers
self._epoch_switch = epoch_switch
@@ -52,9 +56,9 @@ def on_train_epoch_start(self, trainer, __):
"""
Switch the optimizer at the start of the specified training epoch.
- :param lightning.pytorch.Trainer trainer: The trainer object managing
- the training process.
- :param _: Placeholder argument (not used).
+ :param Trainer trainer: The trainer object managing the training
+ process.
+ :param __: Placeholder argument, not used.
"""
# Check if the current epoch matches the switch epoch
if trainer.current_epoch == self._epoch_switch:
diff --git a/pina/callback/optim/switch_scheduler.py b/pina/_src/callback/optim/switch_scheduler.py
similarity index 65%
rename from pina/callback/optim/switch_scheduler.py
rename to pina/_src/callback/optim/switch_scheduler.py
index 3641f4ee4..61284fb50 100644
--- a/pina/callback/optim/switch_scheduler.py
+++ b/pina/_src/callback/optim/switch_scheduler.py
@@ -1,30 +1,31 @@
"""Module for the SwitchScheduler callback."""
from lightning.pytorch.callbacks import Callback
-from ...optim import TorchScheduler
-from ...utils import check_consistency, check_positive_integer
+from pina._src.optim.scheduler_interface import SchedulerInterface
+from pina._src.core.utils import check_consistency, check_positive_integer
class SwitchScheduler(Callback):
"""
- Callback to switch scheduler during training.
+ Lightning callback for dynamically replacing schedulers during training.
+
+ This callback enables switching to new scheduler(s) at a specified epoch
+ without interrupting the training loop. It is useful for staged training
+ strategies where different learning rate policies are applied sequentially.
"""
def __init__(self, new_schedulers, epoch_switch):
"""
- This callback allows switching between different schedulers during
- training, enabling the exploration of multiple optimization strategies
- without interrupting the training process.
+ Initialization of the :class:`SwitchScheduler` class.
:param new_schedulers: The scheduler or list of schedulers to switch to.
Use a single scheduler for single-model solvers, or a list of
schedulers when working with multiple models.
- :type new_schedulers: pina.optim.TorchScheduler |
- list[pina.optim.TorchScheduler]
+ :type new_schedulers: SchedulerInterface | list[SchedulerInterface]
:param int epoch_switch: The epoch at which the scheduler switch occurs.
- :raise AssertionError: If epoch_switch is less than 1.
- :raise ValueError: If each scheduler in ``new_schedulers`` is not an
- instance of :class:`pina.optim.TorchScheduler`.
+ :raises AssertionError: If ``epoch_switch`` is not a positive integer.
+ :raises ValueError: If any of the provided schedulers are not instances
+ of :class:`pina.optim.SchedulerInterface`.
Example:
>>> scheduler = TorchScheduler(
@@ -36,17 +37,14 @@ def __init__(self, new_schedulers, epoch_switch):
"""
super().__init__()
- # Check if epoch_switch is greater than 1
- check_positive_integer(epoch_switch - 1, strict=True)
+ # Check consistency
+ check_positive_integer(epoch_switch, strict=True)
+ check_consistency(new_schedulers, SchedulerInterface)
# If new_schedulers is not a list, convert it to a list
if not isinstance(new_schedulers, list):
new_schedulers = [new_schedulers]
- # Check consistency
- for scheduler in new_schedulers:
- check_consistency(scheduler, TorchScheduler)
-
# Store the new schedulers and epoch switch
self._new_schedulers = new_schedulers
self._epoch_switch = epoch_switch
@@ -55,9 +53,9 @@ def on_train_epoch_start(self, trainer, __):
"""
Switch the scheduler at the start of the specified training epoch.
- :param lightning.pytorch.Trainer trainer: The trainer object managing
+ :param Trainer trainer: The trainer object managing
the training process.
- :param __: Placeholder argument (not used).
+ :param __: Placeholder argument, not used.
"""
# Check if the current epoch matches the switch epoch
if trainer.current_epoch == self._epoch_switch:
diff --git a/pina/_src/callback/processing/__init__.py b/pina/_src/callback/processing/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/callback/processing/data_normalizer.py b/pina/_src/callback/processing/data_normalizer.py
new file mode 100644
index 000000000..515ed51d7
--- /dev/null
+++ b/pina/_src/callback/processing/data_normalizer.py
@@ -0,0 +1,206 @@
+"""Module for the Data Normalizer callback."""
+
+from typing import Callable
+import torch
+from lightning.pytorch import Callback
+from pina._src.core.utils import check_consistency
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.condition.condition import InputTargetCondition
+
+
+class DataNormalizer(Callback):
+ r"""
+ Callback for dataset normalization on input-target conditions.
+
+ This callback computes and applies a normalization transform to either
+ input or target tensors within a dataset. The transformation is defined as:
+
+ .. math::
+
+ x_{\text{norm}} = \frac{x - \mu}{\sigma},
+
+ where :math:`\mu` and :math:`\sigma` are computed using the provided
+ ``shift_fn`` and ``scale_fn`` functions, respectively. Normalization
+ parameters are estimated from the training dataset and then applied in-place
+ to the selected datasets depending on the chosen stage.
+
+ .. note::
+
+ This callback ignores all conditions that are not instances of
+ :class:`~pina.condition.InputTargetCondition`.
+
+ :Example:
+
+ >>> DataNormalizer(
+ ... scale_fn=torch.std,
+ ... shift_fn=torch.mean,
+ ... stage="all",
+ ... apply_to="input",
+ ... )
+ """
+
+ # Define valid options for stage and apply_to parameters
+ _VALID_STAGES = {"train", "validate", "test", "all"}
+ _VALID_APPLY_TO = {"input", "target"}
+
+ def __init__(
+ self,
+ scale_fn=torch.std,
+ shift_fn=torch.mean,
+ stage="all",
+ apply_to="input",
+ ):
+ """
+ Initialization of the :class:`DataNormalizer` class.
+
+ :param Callable scale_fn: The function used to compute the scaling
+ factor. Default is ``torch.std``.
+ :param Callable shift_fn: The function used to compute the shifting
+ factor. Default is ``torch.mean``.
+ :param str stage: The stage during which normalization is applied.
+ Available options are ``"train"``, ``"validate"``, ``"test"``, and
+ ``"all"``. Default is ``"all"``.
+ :param str apply_to: Specifies whether normalization is applied to
+ ``"input"`` or ``"target"`` tensors. Default is ``"input"``.
+ :raises ValueError: If ``scale_fn`` is not Callable.
+ :raises ValueError: If ``shift_fn`` is not Callable.
+ :raises ValueError: If ``stage`` is invalid.
+ :raises ValueError: If ``apply_to`` is invalid.
+ """
+ super().__init__()
+
+ # Check consistency
+ check_consistency(scale_fn, Callable)
+ check_consistency(shift_fn, Callable)
+ check_consistency(stage, str)
+ check_consistency(apply_to, str)
+
+ # Validate stage parameter
+ if stage not in self._VALID_STAGES:
+ raise ValueError(
+ "Invalid value for 'stage'. Available options are "
+ f"{self._VALID_STAGES}. Got {stage}."
+ )
+
+ # Validate apply_to parameter
+ if apply_to not in self._VALID_APPLY_TO:
+ raise ValueError(
+ "Invalid value for 'apply_to'. Available options are "
+ f"{self._VALID_APPLY_TO}. Got {apply_to}."
+ )
+
+ # Initialize attributes
+ self.scale_fn = scale_fn
+ self.shift_fn = shift_fn
+ self.stage = stage
+ self.apply_to = apply_to
+ self._normalizer = {}
+ self._normalized_conditions = set()
+
+ def setup(self, trainer, pl_module, stage):
+ """
+ Compute and apply normalization during the setup phase.
+
+ :param Trainer trainer: The trainer instance managing the execution.
+ :param BaseSolver pl_module: The solver module being executed.
+ :param str stage: Current execution stage.
+ :raises NotImplementedError: If the dataset is graph-based and
+ therefore unsupported.
+ """
+ # Check if any condition contains graph-based data
+ if any(
+ hasattr(ds.condition.data, "graph_key")
+ for ds in trainer.datamodule.train_datasets.values()
+ ):
+ raise NotImplementedError(
+ "DataNormalizer is not compatible with graph-based datasets."
+ )
+
+ # Extract input-target conditions
+ conditions_to_normalize = [
+ name
+ for name, cond in pl_module.problem.conditions.items()
+ if isinstance(cond, InputTargetCondition)
+ ]
+
+ # Extract the dictionary of all datasets
+ dataset = trainer.datamodule.train_datasets
+
+ # Compute scale and shift parameters if not already computed
+ if not self.normalizer:
+
+ # Iterate over conditions and compute normalization parameters
+ for cond in conditions_to_normalize:
+ pts = self._get_data(dataset, cond)
+ shift = self.shift_fn(pts)
+ scale = self.scale_fn(pts)
+
+ self._normalizer[cond] = {
+ "shift": shift,
+ "scale": scale,
+ }
+
+ # Apply normalization to training datasets
+ if stage == "fit" and self.stage in ["train", "all"]:
+ self.normalize_dataset(trainer.datamodule.train_datasets)
+
+ if stage == "fit" and self.stage in ["validate", "all"]:
+ self.normalize_dataset(trainer.datamodule.val_datasets)
+
+ if stage == "test" and self.stage in ["test", "all"]:
+ self.normalize_dataset(trainer.datamodule.test_datasets)
+
+ return super().setup(trainer, pl_module, stage)
+
+ def normalize_dataset(self, dataset):
+ """
+ Apply normalization to all datasets in-place.
+
+ Each condition is updated using precomputed normalization parameters.
+ The transformation preserves tensor types.
+
+ :param dict dataset: The mapping between condition names and their
+ associated dataset subsets.
+ """
+ # Iterate over conditions and apply normalization
+ for cond, norm_params in self.normalizer.items():
+ if cond in self._normalized_conditions:
+ continue
+
+ # Extract the points to normalize and the normalization parameters
+ data_container = getattr(dataset[cond].condition, self.apply_to)
+ points = data_container.data
+ scale = norm_params["scale"]
+ shift = norm_params["shift"]
+
+ # Apply normalization
+ scaled_pts = (points - shift) / scale
+ if isinstance(data_container, LabelTensor):
+ scaled_pts = LabelTensor(scaled_pts, data_container.labels)
+
+ # Update the dataset in-place
+ data_container.data = scaled_pts
+ self._normalized_conditions.add(cond)
+
+ def _get_data(self, dataset, cond):
+ """
+ Extract the selected data field from the dataset for a given condition.
+
+ :param dict dataset: The mapping between condition names and their
+ associated dataset subsets.
+ :param str cond: The condition name.
+ :return: The selected input or target data.
+ :rtype: torch.Tensor
+ """
+ return getattr(dataset[cond].condition, self.apply_to).data
+
+ @property
+ def normalizer(self):
+ """
+ The dictionary mapping each condition to its corresponding ``shift`` and
+ ``scale`` values.
+
+ :return: The dictionary of normalization parameters.
+ :rtype: dict
+ """
+ return self._normalizer
diff --git a/pina/_src/callback/processing/metric_tracker.py b/pina/_src/callback/processing/metric_tracker.py
new file mode 100644
index 000000000..360a5aacb
--- /dev/null
+++ b/pina/_src/callback/processing/metric_tracker.py
@@ -0,0 +1,107 @@
+"""Module for the Metric Tracker."""
+
+import copy
+import torch
+from lightning.pytorch.callbacks import Callback
+from pina._src.core.utils import check_consistency
+
+
+class MetricTracker(Callback):
+ """
+ Callback for collecting selected metrics logged during training.
+ """
+
+ def __init__(self, metrics_to_track=None):
+ """
+ Initialization of the :class:`MetricTracker` class.
+
+ :param metrics_to_track: The names of the metrics to collect. If
+ ``None``, defaults to ``["train_loss"]`` when no batch size is
+ available, otherwise to ``["train_loss_epoch"]``. Default is
+ ``None``.
+ :type metrics_to_track: str | list[str]
+ :raises ValueError: If any of the provided metric names are not strings.
+ """
+ super().__init__()
+
+ # Check consistency
+ if metrics_to_track is not None:
+ check_consistency(metrics_to_track, str)
+
+ # Convert to list if a single string is provided
+ if isinstance(metrics_to_track, str):
+ metrics_to_track = [metrics_to_track]
+
+ # Initialize the collection list and store the metrics to track
+ self.metrics_to_track = metrics_to_track
+ self._collection = []
+
+ def setup(self, trainer, pl_module, stage):
+ """
+ Configure the metrics to track before execution starts.
+
+ When a batch size is provided (i.e. ``trainer.batch_size`` is not
+ ``None``), metric names are expanded to match Lightning's logging
+ convention: for each metric ``m``, both ``m_step`` and ``m_epoch`` are
+ tracked. For example, ``"train_loss"`` becomes
+ ``["train_loss_step", "train_loss_epoch"]``.
+
+ :param Trainer trainer: The trainer instance managing the execution.
+ :param BaseSolver pl_module: The solver module being executed.
+ :param str stage: Current execution stage.
+ """
+ # Set default metrics to train_loss if no batch size is available
+ if self.metrics_to_track is None:
+ self.metrics_to_track = ["train_loss"]
+
+ # If a batch size is provided, expand metric names to match convention
+ if trainer.batch_size is not None:
+ self.metrics_to_track = [
+ f"{metric}_{suffix}"
+ for metric in self.metrics_to_track
+ for suffix in ("step", "epoch")
+ ]
+
+ return super().setup(trainer, pl_module, stage)
+
+ def on_train_epoch_end(self, trainer, __):
+ """
+ Store the selected logged metrics at the end of each training epoch.
+
+ :param Trainer trainer: The trainer instance managing the execution.
+ :param __: Placeholder argument, not used.
+ """
+ # Only collect metrics after the first epoch to ensure they are logged
+ if trainer.current_epoch > 0:
+
+ # Collect the metrics that are being tracked
+ tracked_metrics = {
+ k: v
+ for k, v in trainer.logged_metrics.items()
+ if k in self.metrics_to_track
+ }
+ self._collection.append(copy.deepcopy(tracked_metrics))
+
+ @property
+ def metrics(self):
+ """
+ Return the collected metrics stacked over the tracked epochs.
+
+ :return: The dictionary mapping each metric name to a tensor containing
+ its values across epochs. Returns an empty dictionary if no metrics
+ have been collected.
+ :rtype: dict[str, torch.Tensor]
+ """
+ if not self._collection:
+ return {}
+
+ # Identify the common keys across all collected metric dictionaries
+ common_keys = set(self._collection[0]).intersection(
+ *self._collection[1:]
+ )
+
+ return {
+ k: torch.stack([dic[k] for dic in self._collection])
+ for k in common_keys
+ if k in self.metrics_to_track
+ }
diff --git a/pina/_src/callback/processing/pina_progress_bar.py b/pina/_src/callback/processing/pina_progress_bar.py
new file mode 100644
index 000000000..bde274052
--- /dev/null
+++ b/pina/_src/callback/processing/pina_progress_bar.py
@@ -0,0 +1,116 @@
+"""Module for the Processing Callbacks."""
+
+from lightning.pytorch.callbacks import TQDMProgressBar
+from lightning.pytorch.callbacks.progress.progress_bar import (
+ get_standard_metrics,
+)
+from pina._src.core.utils import check_consistency
+
+
+class PINAProgressBar(TQDMProgressBar):
+ """
+ Custom progress bar callback for PINA training workflows.
+
+ This callback extends the default Lightning progress bar by filtering the
+ displayed metrics.
+
+ Metrics can refer either to condition-specific losses, identified by the
+ names assigned to the problem conditions, or to global losses. Global losses
+ are selected using ``"train"``, ``"val"``, or ``"test"``, and are internally
+ expanded to the corresponding logged loss metrics.
+ """
+
+ GLOBAL_LOSS_KEYS = ("train", "val", "test")
+
+ BAR_FORMAT = (
+ "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, "
+ "{rate_noinv_fmt}{postfix}]"
+ )
+
+ def __init__(self, metrics="val", **kwargs):
+ """
+ Initialization of the :class:`PINAProgressBar`.
+
+ :param metrics: The names of the metrics to be shown in the progress
+ bar. Each entry can be either a key of a condition defined in the
+ problem or one of the global loss keys: ``"train"``, ``"val"``, or
+ ``"test"``. These global keys are internally expanded to the
+ corresponding logged loss names. Default is ``"val"``.
+ :type metrics: str | list(str) | tuple(str)
+ :param dict kwargs: Additional keyword arguments passed to
+ :class:`lightning.pytorch.callbacks.TQDMProgressBar`.
+ :raises TypeError: If ``metrics`` contains non-string elements.
+ """
+ super().__init__(**kwargs)
+
+ # Check consistency
+ check_consistency(metrics, str)
+
+ # Convert to list if a single string is provided
+ if isinstance(metrics, str):
+ metrics = [metrics]
+
+ # Store the sorted metrics for later use in get_metrics
+ self._sorted_metrics = sorted(metrics)
+
+ def get_metrics(self, trainer, __):
+ """
+ Retrieve and filter metrics to be displayed in the progress bar.
+
+ This method combines standard Lightning metrics with user-selected
+ progress bar metrics, retaining only the metrics specified at
+ initialization.
+
+ :param Trainer trainer: The trainer managing the training loop.
+ :param __: Placeholder argument, not used.
+ :return: Dictionary containing the metrics to display.
+ :rtype: dict
+
+ .. note::
+ This method overrides the default Lightning behavior. It can be
+ further customized by subclassing.
+ """
+ # Retrieve standard metrics and user-selected progress bar metrics
+ standard_metrics = get_standard_metrics(trainer)
+ progress_bar_metrics = trainer.progress_bar_metrics
+
+ # Filter progress bar metrics to include only specified keys
+ if progress_bar_metrics:
+ progress_bar_metrics = {
+ key: progress_bar_metrics[key]
+ for key in progress_bar_metrics
+ if key in self._sorted_metrics
+ }
+
+ return {**standard_metrics, **progress_bar_metrics}
+
+ def setup(self, trainer, pl_module, stage):
+ """
+ Configure the metrics to track before execution starts.
+
+ The requested metrics must be either names assigned to problem
+ conditions or global loss keys. The accepted global loss keys are
+ ``"train"``, ``"val"``, and ``"test"``.
+
+ :param Trainer trainer: The trainer instance managing the execution.
+ :param BaseSolver pl_module: The solver module being executed.
+ :param str stage: Current execution stage.
+ :raises KeyError: If a metric key is neither a condition key nor one of
+ ``"train"``, ``"val"``, or ``"test"``.
+ """
+ # Get the condition keys from the problem
+ condition_keys = trainer.solver.problem.conditions.keys()
+ for key in self._sorted_metrics:
+ if key not in condition_keys and key not in self.GLOBAL_LOSS_KEYS:
+ raise KeyError(
+ f"Key '{key}' is not a valid metric. It must be either a "
+ f"problem condition key or one of {self.GLOBAL_LOSS_KEYS}."
+ )
+
+ # Add the appropriate suffix to the metric names based on batch size
+ suffix = "_loss_epoch" if trainer.batch_size is not None else "_loss"
+ self._sorted_metrics = [
+ metric + suffix for metric in self._sorted_metrics
+ ]
+
+ return super().setup(trainer, pl_module, stage)
diff --git a/pina/_src/callback/refinement/__init__.py b/pina/_src/callback/refinement/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/callback/refinement/base_refinement.py b/pina/_src/callback/refinement/base_refinement.py
new file mode 100644
index 000000000..640528975
--- /dev/null
+++ b/pina/_src/callback/refinement/base_refinement.py
@@ -0,0 +1,155 @@
+"""Module for the Base Refinement class."""
+
+from pina._src.solver.physics_informed_single_model_solver import (
+ PhysicsInformedSingleModelSolver,
+)
+from lightning.pytorch import Callback
+from pina._src.core.utils import check_consistency, check_positive_integer
+from pina._src.callback.refinement.refinement_interface import (
+ RefinementInterface,
+)
+
+
+class BaseRefinement(Callback, RefinementInterface):
+ """
+ Base class for all refinement strategies, implementing common functionality.
+
+ A refinement strategy is responsible for dynamically updating the training
+ dataset during optimization, typically by resampling points in the domain
+ based on model behavior (e.g., error-driven refinement).
+
+ All specific refinement strategies should inherit from this class and
+ implement its abstract methods.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ def __init__(self, sample_every, condition_to_update=None):
+ """
+ Initialization of the :class:`BaseRefinement` class.
+
+ :param int sample_every: The number of epochs between successive
+ refinement steps.
+ :param condition_to_update: The condition(s) to be updated during
+ refinement. If ``None``, all conditions associated with a domain are
+ updated. Default is ``None``.
+ :type condition_to_update: str | list[str] | tuple[str]
+ :raises AssertionError: If ``sample_every`` is not a positive integer.
+ :raises ValueError: If ``condition_to_update``, when provided, is not a
+ string or an iterable of strings.
+ """
+ # Check consistency
+ check_positive_integer(sample_every, strict=True)
+ if condition_to_update is not None:
+ if isinstance(condition_to_update, str):
+ condition_to_update = [condition_to_update]
+ check_consistency([condition_to_update], (list, tuple))
+ check_consistency(condition_to_update, str)
+
+ # Initialize attributes
+ self._condition_to_update = condition_to_update
+ self.sample_every = sample_every
+ self._initial_population_size = None
+ self._dataset = None
+
+ def on_train_start(self, trainer, solver):
+ """
+ This method is called once before training begins and is typically used
+ to initialize datasets, sampling conditions, or internal state.
+
+ :param Trainer trainer: The trainer managing the training loop.
+ :param BaseSolver solver: The solver associated with the trainer.
+ :raise RuntimeError: If the solver is not physics-informed (i.e., does
+ not implement PINNInterface).
+ :raise RuntimeError: If any of the specified conditions do not exist in
+ the problem.
+ :raise RuntimeError: If any of the specified conditions do not have a
+ 'domain' attribute for sampling.
+ """
+ # Check solver consistency
+ if not isinstance(solver, PhysicsInformedSingleModelSolver):
+ raise RuntimeError(
+ "Refinement strategies require a physics-informed solver. "
+ f"Got '{type(solver).__name__}'."
+ )
+
+ # Initialize conditions to update if not provided
+ if self._condition_to_update is None:
+ self._condition_to_update = [
+ name
+ for name, cond in solver.problem.conditions.items()
+ if hasattr(cond, "domain")
+ ]
+
+ # Validate conditions and solver
+ for cond in self._condition_to_update:
+
+ # Check if condition exists in the problem
+ if cond not in solver.problem.conditions:
+ raise RuntimeError(
+ f"Unknown condition '{cond}'. Available conditions: "
+ f"{list(solver.problem.conditions.keys())}."
+ )
+
+ # Check if condition has a domain to sample from
+ if not hasattr(solver.problem.conditions[cond], "domain"):
+ raise RuntimeError(
+ f"Condition '{cond}' has no 'domain' attribute and cannot "
+ "be used for sampling."
+ )
+
+ # Initialize dataset and compute initial population size
+ self._dataset = trainer.datamodule.train_datasets
+ self._initial_population_size = {
+ cond: self.dataset[cond].dataset_length
+ for cond in self._condition_to_update
+ }
+
+ def on_train_epoch_end(self, trainer, solver):
+ """
+ Apply refinement at the end of a training epoch.
+
+ This method is invoked after each epoch and can update the dataset based
+ on the current state of the model.
+
+ :param Trainer trainer: The trainer managing the training loop.
+ :param BaseSolver solver: The solver associated with the trainer.
+ """
+ # Store current epoch
+ epoch = trainer.current_epoch
+
+ # Sample if it's time to refine
+ if epoch % self.sample_every == 0 and epoch != 0:
+
+ # Update points for each condition to update
+ for name in self._condition_to_update:
+
+ current_points = solver.problem.conditions[name].data.input
+ new_points = self.sample(current_points, name, solver)
+ solver.problem.conditions[name].data.input = new_points
+
+ @property
+ def dataset(self):
+ """
+ The training datasets managed by the refinement strategy.
+
+ The dataset is stored as a dictionary whose keys are condition names and
+ whose values are the corresponding dataset subsets. The content of this
+ dictionary can be updated dynamically during refinement.
+
+ :return: The mapping between condition names and dataset subsets.
+ :rtype: dict
+ """
+ return self._dataset
+
+ @property
+ def initial_population_size(self):
+ """
+ Initial size of the sampled dataset for each condition before any
+ refinement is applied.
+
+ :return: A mapping between each condition name and its initial number
+ of sampled points.
+ :rtype: dict[str, int]
+ """
+ return self._initial_population_size
diff --git a/pina/_src/callback/refinement/r3_refinement.py b/pina/_src/callback/refinement/r3_refinement.py
new file mode 100644
index 000000000..1e52c4b4b
--- /dev/null
+++ b/pina/_src/callback/refinement/r3_refinement.py
@@ -0,0 +1,105 @@
+"""Module for the R3Refinement callback."""
+
+import torch
+from pina._src.core.utils import check_consistency
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.loss.dual_loss_interface import DualLossInterface
+from pina._src.callback.refinement.base_refinement import BaseRefinement
+
+
+class R3Refinement(BaseRefinement):
+ """
+ Refinement strategy based on the R3 (Retain-Resample-Release) algorithm.
+
+ This method adaptively updates collocation points by retaining points with
+ high residuals, resampling new points in the domain, releasing points with
+ low residuals.
+
+ The objective is to concentrate sampling in regions where the PDE residual
+ is large, improving training efficiency and solution accuracy.
+
+ .. seealso::
+
+ **Original Reference**: Daw, Arka, et al. (2023).
+ *Mitigating Propagation Failures in Physics-informed Neural Networks
+ using Retain-Resample-Release (R3) Sampling*.
+ DOI: `10.48550/arXiv.2207.02338
+ `_
+
+ :Example:
+
+ >>> r3 = R3Refinement(sample_every=5)
+ """
+
+ def __init__(
+ self,
+ sample_every,
+ residual_loss=torch.nn.L1Loss,
+ condition_to_update=None,
+ ):
+ """
+ Initialization of the :class:`R3Refinement` class.
+
+ :param int sample_every: The number of epochs between successive
+ refinement steps.
+ :param residual_loss: The loss used to evaluate residual magnitude. Must
+ be a subclass of :class:`torch.nn.Module` or
+ :class:`pina.loss.DualLossInterface`.
+ Default is :class:`torch.nn.L1Loss`.
+ :type residual_loss: DualLossInterface | torch.nn.modules.loss._Loss
+ :param condition_to_update: The condition(s) to be updated during
+ refinement. If ``None``, all conditions associated with a domain are
+ updated. Default is ``None``.
+ :type condition_to_update: str | list[str] | tuple[str]
+ :raises ValueError: If the condition_to_update is neither a string nor
+ an iterable of strings.
+ :raises ValueError: If the residual_loss is not a valid loss class.
+ """
+ super().__init__(sample_every, condition_to_update)
+
+ # Check consistency
+ check_consistency(
+ residual_loss,
+ (DualLossInterface, torch.nn.modules.loss._Loss),
+ subclass=True,
+ )
+
+ # Store the loss function for computing residuals during sampling
+ self.loss_fn = residual_loss(reduction="none")
+
+ def sample(self, current_points, condition_name, solver):
+ """
+ Generate new sample points for a given condition.
+
+ :param LabelTensor current_points: The existing points in the domain.
+ :param str condition_name: The identifier of the condition to refine.
+ :param BaseSolver solver: The solver used for sampling decisions.
+ :return: Newly sampled points.
+ :rtype: LabelTensor
+ """
+ # Retrieve condition and current points
+ device = solver.trainer.strategy.root_device
+ condition = solver.problem.conditions[condition_name]
+ current_points = current_points.to(device).requires_grad_(True)
+
+ # Compute residuals for the given condition
+ target = condition.evaluate({"input": current_points}, solver)
+ residuals = self.loss_fn(target, torch.zeros_like(target)).mean(
+ dim=tuple(range(1, target.ndim))
+ )
+
+ # Retrieve domain and initial population size
+ domain_name = solver.problem.conditions[condition_name].domain
+ domain = solver.problem.domains[domain_name]
+ num_old_points = self.initial_population_size[condition_name]
+
+ # Select points with residual above the mean
+ mask = (residuals >= residuals.mean()).flatten()
+ high_residual_pts = current_points[mask]
+ high_residual_pts.labels = current_points.labels
+
+ # Sample new points to maintain the initial population size
+ num_new_pts = max(num_old_points - len(high_residual_pts), 0)
+ samples = domain.sample(num_new_pts, "random").to(device)
+
+ return LabelTensor.cat([high_residual_pts, samples])
diff --git a/pina/_src/callback/refinement/refinement_interface.py b/pina/_src/callback/refinement/refinement_interface.py
new file mode 100644
index 000000000..320d526af
--- /dev/null
+++ b/pina/_src/callback/refinement/refinement_interface.py
@@ -0,0 +1,69 @@
+"""Module for the Refinement Interface."""
+
+from abc import ABCMeta, abstractmethod
+
+
+class RefinementInterface(metaclass=ABCMeta):
+ """
+ Abstract interface for all refinement strategies.
+ """
+
+ @abstractmethod
+ def on_train_start(self, trainer, solver):
+ """
+ This method is called once before training begins and is typically used
+ to initialize datasets, sampling conditions, or internal state.
+
+ :param Trainer trainer: The trainer managing the training loop.
+ :param BaseSolver solver: The solver associated with the trainer.
+ """
+
+ @abstractmethod
+ def on_train_epoch_end(self, trainer, solver):
+ """
+ Apply refinement at the end of a training epoch.
+
+ This method is invoked after each epoch and can update the dataset based
+ on the current state of the model.
+
+ :param Trainer trainer: The trainer managing the training loop.
+ :param BaseSolver solver: The solver associated with the trainer.
+ """
+
+ @abstractmethod
+ def sample(self, current_points, condition_name, solver):
+ """
+ Generate new sample points for a given condition.
+
+ :param LabelTensor current_points: The existing points in the domain.
+ :param str condition_name: The identifier of the condition to refine.
+ :param BaseSolver solver: The solver used for sampling decisions.
+ :return: Newly sampled points.
+ :rtype: LabelTensor
+ """
+
+ @property
+ @abstractmethod
+ def dataset(self):
+ """
+ The training datasets managed by the refinement strategy.
+
+ The dataset is stored as a dictionary whose keys are condition names and
+ whose values are the corresponding dataset subsets. The content of this
+ dictionary can be updated dynamically during refinement.
+
+ :return: The mapping between condition names and dataset subsets.
+ :rtype: dict
+ """
+
+ @property
+ @abstractmethod
+ def initial_population_size(self):
+ """
+ Initial size of the sampled dataset for each condition before any
+ refinement is applied.
+
+ :return: A mapping between each condition name and its initial number
+ of sampled points.
+ :rtype: dict[str, int]
+ """
diff --git a/pina/_src/condition/__init__.py b/pina/_src/condition/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/condition/base_condition.py b/pina/_src/condition/base_condition.py
new file mode 100644
index 000000000..9adf19b2d
--- /dev/null
+++ b/pina/_src/condition/base_condition.py
@@ -0,0 +1,153 @@
+"""Module for the Base Condition class."""
+
+from functools import partial
+import torch
+from torch_geometric.data import Batch
+from torch.utils.data import DataLoader
+from pina._src.condition.condition_interface import ConditionInterface
+from pina._src.core.graph import LabelBatch
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency
+from pina._src.data.single_batch_data_loader import _SingleBatchDataLoader
+from pina._src.problem.problem_interface import ProblemInterface
+
+
+class BaseCondition(ConditionInterface):
+ """
+ Base class for all conditions, implementing common functionality.
+
+ All specific condition types should inherit from this class and implement
+ the abstract methods of
+ :class:`~pina.condition.condition_interface.ConditionInterface`.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ # Available collate functions for automatic batching
+ collate_fn_dict = {
+ "tensor": torch.stack,
+ "label_tensor": LabelTensor.stack,
+ "graph": LabelBatch.from_data_list,
+ "data": Batch.from_data_list,
+ }
+
+ def __init__(self, **kwargs):
+ """
+ Initialization of the :class:`BaseCondition` class.
+
+ :param dict kwargs: The keyword arguments representing the data to be
+ stored in the condition.
+ """
+ super().__init__()
+ self.data = self.store_data(**kwargs)
+ self.has_custom_dataloader_fn = False
+
+ def __len__(self):
+ """
+ Return the number of data points in the condition.
+
+ :return: The number of data points.
+ :rtype: int
+ """
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ """
+ Return the data point at the specified index.
+
+ :param int idx: The index of the data point to retrieve.
+ :return: The data point at the specified index.
+ :rtype: Any
+ """
+ return self.data[idx]
+
+ def create_dataloader(
+ self, dataset, batch_size, automatic_batching, **kwargs
+ ):
+ """
+ Create the DataLoader for the condition.
+
+ :param _ConditionSubset dataset: The dataset for the DataLoader.
+ :param int batch_size: The batch size for the DataLoader.
+ :param bool automatic_batching: Whether to use automatic batching.
+ :param dict kwargs: Additional keyword arguments for the DataLoader.
+ :return: The DataLoader for the condition.
+ :rtype: torch.utils.data.DataLoader
+ """
+ # If batching the entire dataset, return a _SingleBatchDataLoader
+ if batch_size == len(dataset):
+ return _SingleBatchDataLoader(dataset)
+
+ # Otherwise, return a regular DataLoader with the appropriate collate
+ return DataLoader(
+ dataset=dataset,
+ collate_fn=(
+ partial(self.collate_fn, condition=self)
+ if not automatic_batching
+ else self.automatic_batching_collate_fn
+ ),
+ batch_size=batch_size,
+ **kwargs,
+ )
+
+ def switch_dataloader_fn(self, create_dataloader_fn):
+ """
+ Switch the dataloader function for the condition.
+
+ :param Callable create_dataloader_fn: The new dataloader function to use
+ for the condition.
+ :return: The new dataloader function for the condition.
+ :rtype: Callable
+ """
+ self.has_custom_dataloader_fn = True
+ self.create_dataloader = create_dataloader_fn
+
+ @classmethod
+ def automatic_batching_collate_fn(cls, batch):
+ """
+ Collate function for automatic batching to be used in the DataLoader.
+
+ :param list batch: A list of items from the dataset.
+ :return: A collated batch.
+ :rtype: dict
+ """
+ # If the batch is empty, return an empty dictionary
+ if not batch:
+ return {}
+
+ # Otherwise, collate the batch using the appropriate collate function
+ instance_class = batch[0].__class__
+ return instance_class.create_batch(batch)
+
+ @staticmethod
+ def collate_fn(batch, condition):
+ """
+ Collate function for custom batching to be used in the DataLoader.
+
+ :param list batch: A list of items from the dataset.
+ :param BaseCondition condition: The condition instance.
+ :return: A collated batch.
+ :rtype: dict
+ """
+ return condition.data[batch].to_batch()
+
+ @property
+ def problem(self):
+ """
+ The problem associated with this condition.
+
+ :return: The problem associated with this condition.
+ :rtype: BaseProblem
+ """
+ return self._problem
+
+ @problem.setter
+ def problem(self, value):
+ """
+ Set the problem associated with this condition.
+
+ :param BaseProblem value: The problem to associate with this condition.
+ :raises ValueError: If the problem is not an instance of BaseProblem.
+ """
+ check_consistency(value, ProblemInterface)
+ self._problem = value
diff --git a/pina/_src/condition/condition.py b/pina/_src/condition/condition.py
new file mode 100644
index 000000000..69875a6a8
--- /dev/null
+++ b/pina/_src/condition/condition.py
@@ -0,0 +1,167 @@
+"""Module for the Condition class."""
+
+from pina._src.condition.input_equation_condition import InputEquationCondition
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.condition.time_series_condition import TimeSeriesCondition
+from pina._src.condition.graph_time_series_condition import (
+ GraphTimeSeriesCondition,
+)
+from pina._src.condition.data_condition import DataCondition
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+
+
+class Condition:
+ """
+ The :class:`Condition` class is a core component of the PINA framework that
+ provides a unified interface to define heterogeneous constraints that must
+ be satisfied by a :class:`~pina.problem.base_problem.BaseProblem`.
+
+ It encapsulates all types of constraints - physical, boundary, initial, or
+ data-driven - that the solver must satisfy during training. The specific
+ behavior is inferred from the arguments passed to the constructor.
+
+ Multiple types of conditions can be used within the same problem, allowing
+ for a high degree of flexibility in defining complex problems.
+
+ The :class:`Condition` class behavior specializes internally based on the
+ arguments provided during instantiation. Depending on the specified keyword
+ arguments, the class automatically selects the appropriate internal
+ implementation.
+
+ Available `Condition` types:
+
+ - :class:`~pina.condition.input_target_condition.InputTargetCondition`:
+ represents a supervised condition defined by both ``input`` and ``target``
+ data. The model is trained to reproduce the ``target`` values given the
+ ``input``. Supported data types include :class:`torch.Tensor`,
+ :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
+ :class:`~torch_geometric.data.Data`. The class automatically selects the
+ appropriate implementation based on the types of ``input`` and ``target``.
+
+ - :class:`~pina.condition.domain_equation_condition.DomainEquationCondition`
+ : represents a general physics-informed condition defined by a ``domain``
+ and an ``equation``. The model learns to minimize the equation residual
+ through evaluations performed at points sampled from the specified domain.
+
+ - :class:`~pina.condition.input_equation_condition.InputEquationCondition`:
+ represents a general physics-informed condition defined by ``input``
+ points and an ``equation``. The model learns to minimize the equation
+ residual through evaluations performed at the provided ``input``.
+ Supported data types for the ``input`` include :class:`~pina.graph.Graph`
+ or :class:`~pina.label_tensor.LabelTensor`. The class automatically
+ selects the appropriate implementation based on the types of the
+ ``input``.
+
+ - :class:`~pina.condition.time_series_condition.TimeSeriesCondition`:
+ represents a condition designed for time series data, where the model is
+ trained to capture temporal dependencies and dynamics. It is defined by an
+ ``input`` tensor of shape ``[trajectories, time_steps, *features]``
+ containing time series data. Supported data types for the ``input``
+ include class:`~pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
+ The class automatically selects the appropriate implementation based on
+ the type of the ``input``.
+
+ - :class:`~pina.condition.data_condition.DataCondition`: represents an
+ unsupervised, data-driven condition defined by the ``input`` only.
+ The model is trained using a custom unsupervised loss determined by the
+ chosen :class:`~pina.solver.base_solver.BaseSolver`, while leveraging the
+ provided data during training. Optional ``conditional_variables`` can be
+ specified when the model depends on additional parameters.
+ Supported data types include :class:`~pina.label_tensor.LabelTensor`,
+ :class:`torch.Tensor`, :class:`~torch_geometric.data.Data`, or
+ :class:`~pina.graph.Graph`. The class automatically selects the
+ appropriate implementation based on the type of the ``input``.
+
+ .. note::
+
+ The user should always instantiate :class:`Condition` directly, without
+ manually creating subclass instances. Please refer to the specific
+ :class:`Condition` classes for implementation details.
+
+ :Example:
+
+ >>> from pina import Condition
+
+ >>> # Example of InputTargetCondition signature
+ >>> condition = Condition(input=input, target=target)
+
+ >>> # Example of DomainEquationCondition signature
+ >>> condition = Condition(domain=domain, equation=equation)
+
+ >>> # Example of InputEquationCondition signature
+ >>> condition = Condition(input=input, equation=equation)
+
+ >>> # Example of TimeSeriesCondition signature
+ >>> condition = Condition(
+ ... input=input, n_windows=n_windows, unroll_length=unroll_length
+ ... )
+
+ >>> # Example of DataCondition signature
+ >>> condition = Condition(input=data, conditional_variables=cond_vars)
+ """
+
+ # Internal specifications for condition types, used for dispatching
+ # Each tuple contains: (condition class, required kwargs, optional kwargs)
+ _SPECS = (
+ (InputTargetCondition, {"input", "target"}, set()),
+ (InputEquationCondition, {"input", "equation"}, set()),
+ (DomainEquationCondition, {"domain", "equation"}, set()),
+ (DataCondition, {"input"}, {"conditional_variables"}),
+ (
+ TimeSeriesCondition,
+ {"input", "n_windows", "unroll_length"},
+ {"randomize"},
+ ),
+ (
+ GraphTimeSeriesCondition,
+ {"input", "n_windows", "unroll_length"},
+ {"key", "randomize"},
+ ),
+ )
+
+ # Compute the set of all available keyword arguments (optional + required)
+ available_kwargs = sorted(set().union(*(rq | op for _, rq, op in _SPECS)))
+
+ def __new__(cls, *args, **kwargs):
+ """
+ Instantiate the appropriate :class:`Condition` object based on the
+ keyword arguments passed.
+
+ :param tuple args: The positional arguments (should be empty).
+ :param dict kwargs: The keyword arguments corresponding to the
+ parameters of the specific :class:`Condition` type to instantiate.
+ :raises ValueError: If unexpected positional arguments are provided.
+ :raises ValueError: If the keyword arguments do not match any valid
+ signature for the available condition types.
+ :return: The appropriate :class:`Condition` object.
+ :rtype: ConditionInterface
+ """
+ # Ensure no positional arguments are provided
+ if args:
+ raise ValueError(
+ "Condition takes only keyword arguments. "
+ f"Available arguments are: {cls.available_kwargs}."
+ )
+
+ # Iterate through the specifications to find a matching condition type
+ for condition_cls, required, optional in cls._SPECS:
+
+ # Find allowed keys for condition type
+ allowed = required | optional
+
+ # Check if the provided keys match the required and optional keys
+ if required <= set(kwargs) <= allowed:
+ return condition_cls(**kwargs)
+
+ # If no valid signature is found, prepare a list of valid signatures
+ valid_signatures = [
+ sorted(required | optional) for _, required, optional in cls._SPECS
+ ]
+
+ # If no valid signature is found, raise an error
+ raise ValueError(
+ f"Invalid keyword arguments {sorted(set(kwargs))}. "
+ f"Valid signatures are: {valid_signatures}."
+ )
diff --git a/pina/_src/condition/condition_interface.py b/pina/_src/condition/condition_interface.py
new file mode 100644
index 000000000..baa6a5d99
--- /dev/null
+++ b/pina/_src/condition/condition_interface.py
@@ -0,0 +1,133 @@
+"""Module for the Condition interface."""
+
+from abc import ABCMeta, abstractmethod
+
+
+class ConditionInterface(metaclass=ABCMeta):
+ """
+ Abstract interface for all conditions.
+
+ Refer to :class:`pina.condition.condition.Condition` for a thorough
+ description of all available conditions and how to instantiate them.
+ """
+
+ @abstractmethod
+ def __len__(self):
+ """
+ Return the number of data points in the condition.
+
+ :return: The number of data points.
+ :rtype: int
+ """
+
+ @abstractmethod
+ def __getitem__(self, idx):
+ """
+ Return the data point at the specified index.
+
+ :param int idx: The index of the data point to retrieve.
+ :return: The data point at the specified index.
+ :rtype: Any
+ """
+
+ @abstractmethod
+ def store_data(self, **kwargs):
+ """
+ Store the data for the condition in a suitable format.
+
+ :param dict kwargs: The keyword arguments containing the data to be
+ stored.
+ :return: The stored data in a suitable format.
+ :rtype: Any
+ """
+
+ @abstractmethod
+ def create_dataloader(
+ self, dataset, batch_size, automatic_batching, **kwargs
+ ):
+ """
+ Create the DataLoader for the condition.
+
+ :param _ConditionSubset dataset: The dataset for the DataLoader.
+ :param int batch_size: The batch size for the DataLoader.
+ :param bool automatic_batching: Whether to use automatic batching.
+ :param dict kwargs: Additional keyword arguments for the DataLoader.
+ :return: The DataLoader for the condition.
+ :rtype: torch.utils.data.DataLoader
+ """
+
+ @abstractmethod
+ def evaluate(self, batch, solver):
+ """
+ Evaluate the residual of the condition on the given batch using the
+ solver.
+
+ This method computes the non-aggregated, element-wise residual of the
+ condition. A forward pass of the solver's model is performed on the
+ input samples, and the condition residual is evaluated accordingly.
+
+ The returned tensor is not reduced, preserving the per-sample residual
+ values.
+
+ :param dict batch: The batch containing the data required by the
+ condition evaluation.
+ :param BaseSolver solver: The solver used to perform the forward pass
+ and compute the residual. The solver provides access to the model
+ and its parameters, which may be necessary for evaluating the
+ condition residual.
+ :return: The non-aggregated residual tensor.
+ :rtype: torch.Tensor | LabelTensor
+ """
+
+ @abstractmethod
+ def switch_dataloader_fn(self, create_dataloader_fn):
+ """
+ Switch the dataloader function for the condition.
+
+ :param Callable create_dataloader_fn: The new dataloader function to use
+ for the condition.
+ :return: The new dataloader function for the condition.
+ :rtype: Callable
+ """
+
+ @classmethod
+ @abstractmethod
+ def automatic_batching_collate_fn(cls, batch):
+ """
+ Collate function for automatic batching to be used in the DataLoader.
+
+ :param list batch: A list of items from the dataset.
+ :return: A collated batch.
+ :rtype: dict
+ """
+
+ @staticmethod
+ @abstractmethod
+ def collate_fn(batch, condition):
+ """
+ Collate function for custom batching to be used in the DataLoader.
+
+ :param list batch: A list of items from the dataset.
+ :param BaseCondition condition: The condition instance.
+ :return: A collated batch.
+ :rtype: dict
+ """
+
+ @property
+ @abstractmethod
+ def problem(self):
+ """
+ The problem associated with this condition.
+
+ :return: The problem associated with this condition.
+ :rtype: BaseProblem
+ """
+
+ @problem.setter
+ @abstractmethod
+ def problem(self, value):
+ """
+ Set the problem associated with this condition.
+
+ :param BaseProblem value: The problem to associate with this condition.
+ """
diff --git a/pina/_src/condition/data_condition.py b/pina/_src/condition/data_condition.py
new file mode 100644
index 000000000..ee6ee4c76
--- /dev/null
+++ b/pina/_src/condition/data_condition.py
@@ -0,0 +1,133 @@
+"""Module for the Data Condition class."""
+
+import torch
+from torch_geometric.data import Data
+from pina._src.condition.base_condition import BaseCondition
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.graph import Graph
+from pina._src.data.manager.data_manager import _DataManager
+from pina._src.core.utils import check_consistency
+
+
+class DataCondition(BaseCondition):
+ """
+ The class :class:`DataCondition` defines an unsupervised condition based on
+ ``input`` data. This condition is typically used in data-driven problems,
+ where the model is trained using a custom unsupervised loss determined by
+ the chosen :class:`~pina.solver.base_solver.BaseSolver`, while leveraging
+ the provided data during training. Optional ``conditional_variables`` can be
+ specified when the model depends on additional parameters.
+
+ :Example:
+
+ >>> from pina import Condition, LabelTensor
+ >>> import torch
+
+ >>> pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
+ >>> cond_vars = LabelTensor(torch.randn(100, 1), labels=["w"])
+ >>> condition = Condition(input=pts, conditional_variables=cond_vars)
+ """
+
+ # Available fields, input and conditional variables data types
+ __fields__ = ["input", "conditional_variables"]
+ _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph)
+ _avail_conditional_variables_cls = (torch.Tensor, LabelTensor)
+
+ def __new__(cls, input, conditional_variables=None):
+ """
+ Check the types of ``input`` and ``conditional_variables`` and
+ instantiate an instance of :class:`DataCondition` accordingly.
+
+ :param input: The input data associated with the condition.
+ :type input: torch.Tensor | LabelTensor | Graph |
+ Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
+ :param conditional_variables: The conditional variables associated with
+ the condition. Default is ``None``.
+ :type conditional_variables: torch.Tensor | LabelTensor
+ :raises ValueError: If ``input`` is not of type :class:`torch.Tensor`,
+ :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
+ or :class:`~torch_geometric.data.Data`, nor is it a list or tuple of
+ :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`.
+ :raises ValueError: If ``conditional_variables`` is not of type
+ :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`.
+ :return: A new instance of :class:`DataCondition`.
+ :rtype: DataCondition
+ """
+ # Check input type - if iterable, ensure it is either Data or Graph
+ if isinstance(input, (list, tuple)):
+ check_consistency(input, (Data, Graph))
+ else:
+ check_consistency(input, cls._avail_input_cls)
+
+ # Check conditional_variables type
+ if conditional_variables is not None:
+ check_consistency(
+ conditional_variables, cls._avail_conditional_variables_cls
+ )
+
+ return super().__new__(cls)
+
+ def store_data(self, **kwargs):
+ """
+ Store the input data and the conditional variables in a dictionary-like
+ structure.
+
+ :param dict kwargs: The keyword arguments containing the data to be
+ stored.
+ :return: A dictionary-like structure containing the stored data.
+ :rtype: _DataManager
+ """
+ # Store input and conditional variables in a dictionary-like structure
+ data_dict = {"input": kwargs.get("input")}
+ cond_vars = kwargs.get("conditional_variables", None)
+ if cond_vars is not None:
+ data_dict["conditional_variables"] = cond_vars
+
+ return _DataManager(**data_dict)
+
+ def evaluate(self, batch, solver):
+ """
+ Evaluate the residual of the condition on the given batch using the
+ solver.
+
+ This method computes the non-aggregated, element-wise residual of the
+ condition. A forward pass of the solver's model is performed on the
+ input samples, and the condition residual is evaluated accordingly.
+
+ The returned tensor is not reduced, preserving the per-sample residual
+ values.
+
+ :param dict batch: The batch containing the data required by the
+ condition evaluation.
+ :param BaseSolver solver: The solver used to perform the forward pass
+ and compute the residual. The solver provides access to the model
+ and its parameters, which may be necessary for evaluating the
+ condition residual.
+ :return: The non-aggregated residual tensor.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ return solver.forward(batch["input"])
+
+ @property
+ def conditional_variables(self):
+ """
+ The conditional variables associated with the condition.
+
+ :return: The conditional variables.
+ :rtype: torch.Tensor | LabelTensor | None
+ """
+ if hasattr(self.data, "conditional_variables"):
+ return self.data.conditional_variables
+
+ return None
+
+ @property
+ def input(self):
+ """
+ The input data associated with the condition.
+
+ :return: The input data.
+ :rtype: torch.Tensor | LabelTensor | Graph | Data |
+ list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
+ """
+ return self.data.input
diff --git a/pina/_src/condition/domain_equation_condition.py b/pina/_src/condition/domain_equation_condition.py
new file mode 100644
index 000000000..4641d3933
--- /dev/null
+++ b/pina/_src/condition/domain_equation_condition.py
@@ -0,0 +1,153 @@
+"""Module for the Domain-Equation Condition class."""
+
+from pina._src.condition.base_condition import BaseCondition
+from pina._src.domain.domain_interface import DomainInterface
+from pina._src.equation.base_equation import BaseEquation
+from pina._src.core.utils import check_consistency
+
+
+class DomainEquationCondition(BaseCondition):
+ """
+ The class :class:`DomainEquationCondition` defines a condition based on a
+ ``domain`` and an ``equation``. This condition is typically used in
+ physics-informed problems, where the model is trained to satisfy a given
+ ``equation`` over a specified ``domain``. The ``domain`` is used to sample
+ points where the ``equation`` residual is evaluated and minimized during
+ training.
+
+ :Example:
+
+ >>> from pina.domain import CartesianDomain
+ >>> from pina.equation import Equation
+ >>> from pina import Condition
+
+ >>> # Equation to be satisfied over the domain: # x^2 + y^2 - 1 = 0
+ >>> def dummy_equation(pts):
+ ... return pts["x"]**2 + pts["y"]**2 - 1
+
+ >>> domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})
+ >>> condition = Condition(domain=domain, equation=Equation(dummy_equation))
+ """
+
+ # Available fields, domain and equation data types
+ __fields__ = ["domain", "equation"]
+ _avail_domain_cls = (DomainInterface, str)
+ _avail_equation_cls = BaseEquation
+
+ def __len__(self):
+ """
+ Return the number of data points in the condition.
+
+ :raises NotImplementedError: Always raised since the number of points is
+ determined by the domain sampling strategy and is not fixed.
+ """
+ raise NotImplementedError(
+ "The number of data points in a DomainEquationCondition is not "
+ "fixed and is determined by the domain sampling strategy. "
+ "Therefore, the :meth:`__len__` method is not implemented for this "
+ "condition."
+ )
+
+ def __getitem__(self, idx):
+ """
+ Return the data point at the specified index.
+
+ :raises NotImplementedError: Always raised since the data points are not
+ stored in a list-like structure and cannot be accessed by index.
+ """
+ raise NotImplementedError(
+ "Data points in a DomainEquationCondition are not stored in a "
+ "list-like structure and cannot be accessed by index. Therefore, "
+ "the :meth:`__getitem__` method is not implemented for this "
+ "condition."
+ )
+
+ def store_data(self, **kwargs):
+ """
+ Store the domain and the equation for the condition. It sets the
+ attributes ``domain`` and ``equation`` of the condition instance based
+ on the provided keyword arguments.
+
+ :param dict kwargs: The keyword arguments containing the data to be
+ stored.
+ """
+ # Store domain and equation as attributes of the condition instance
+ setattr(self, "domain", kwargs.get("domain"))
+ setattr(self, "equation", kwargs.get("equation"))
+
+ def evaluate(self, batch, solver):
+ """
+ Evaluate the residual of the condition on the given batch using the
+ solver.
+
+ This method computes the non-aggregated, element-wise residual of the
+ condition. A forward pass of the solver's model is performed on the
+ input samples, and the condition residual is evaluated accordingly.
+
+ The returned tensor is not reduced, preserving the per-sample residual
+ values.
+
+ :param dict batch: The batch containing the data required by the
+ condition evaluation.
+ :param BaseSolver solver: The solver used to perform the forward pass
+ and compute the residual. The solver provides access to the model
+ and its parameters, which may be necessary for evaluating the
+ condition residual.
+ :raises NotImplementedError: Always raised since any domain-equation
+ condition is transformed into an input-equation condition before
+ evaluation, and the residual is computed using the input-equation
+ condition's evaluation method.
+ """
+ raise NotImplementedError(
+ "Domain-equation conditions are transformed into input-equation "
+ "conditions before evaluation, and the residual is computed using "
+ "the input-equation condition's evaluation method. Therefore, the "
+ "evaluate method is not implemented for domain-equation conditions."
+ )
+
+ @property
+ def equation(self):
+ """
+ The equation associated with the condition.
+
+ :return: The equation.
+ :rtype: BaseEquation
+ """
+ return self._equation
+
+ @equation.setter
+ def equation(self, value):
+ """
+ Set the equation associated with this condition.
+
+ :param BaseEquation value: The equation to associate with the condition.
+ :raises ValueError: If ``value`` is not an instance of
+ :class:`~pina.equation.base_equation.BaseEquation`.
+ """
+ # Check consistency
+ check_consistency(value, self._avail_equation_cls)
+ self._equation = value
+
+ @property
+ def domain(self):
+ """
+ The domain associated with the condition.
+
+ :return: The domain.
+ :rtype: DomainInterface
+ """
+ return self._domain
+
+ @domain.setter
+ def domain(self, value):
+ """
+ Set the domain associated with this condition.
+
+ :param DomainInterface value: The domain to associate with the
+ condition.
+ :raises ValueError: If ``value`` is neither a string nor an instance of
+ :class:`~pina.domain.domain_interface.DomainInterface`.
+ """
+ # Check consistency
+ check_consistency(value, self._avail_domain_cls)
+ self._domain = value
diff --git a/pina/_src/condition/graph_time_series_condition.py b/pina/_src/condition/graph_time_series_condition.py
new file mode 100644
index 000000000..77ece43d6
--- /dev/null
+++ b/pina/_src/condition/graph_time_series_condition.py
@@ -0,0 +1,147 @@
+"""Module for the TimeSeriesCondition class."""
+
+import torch
+from pina._src.core.utils import check_consistency, check_positive_integer
+from pina._src.data.manager.data_manager import _DataManager
+from pina._src.condition.time_series_condition import TimeSeriesCondition
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.condition.base_condition import BaseCondition
+from torch_geometric.data import Data
+from pina._src.core.graph import Graph
+
+
+class GraphTimeSeriesCondition(TimeSeriesCondition):
+ """
+ The :class:`TimeSeriesCondition` class represents an autoregressive time
+ series condition defined by temporal ``input`` data. The input is expected
+ to have shape ``[trajectories, time_steps, *features]``, where the second
+ dimension corresponds to the temporal evolution of each trajectory.
+
+ During training, the condition automatically extracts overlapping temporal
+ windows from the trajectories. The parameter ``unroll_length`` defines the
+ number of consecutive time steps contained in each temporal window, while
+ ``n_windows`` controls how many temporal windows are created from the
+ available trajectories.
+
+ Internally, the unrolled data is stored as a tensor of shape
+ ``[trajectories, n_windows, unroll_length, *features]``.
+
+ Supported data types include :class:`~pina.label_tensor.LabelTensor` and
+ :class:`torch.Tensor`.
+
+ :Example:
+
+ >>> from pina import Condition, LabelTensor
+ >>> import torch
+
+ >>> data = LabelTensor(torch.rand(5, 10, 2), labels=["u", "v"])
+ >>> condition = Condition(input=data, unroll_length=5, n_windows=3)
+ """
+
+ # Available fields and input data types
+ __fields__ = ["input", "unroll_length", "n_windows", "key", "randomize"]
+ _avail_input_cls = (Data, Graph)
+
+ def __new__(cls, input, n_windows, unroll_length, key="x", randomize=False):
+ # Check consistency
+ check_consistency(input, cls._avail_input_cls)
+ check_consistency(randomize, bool)
+ check_consistency(key, str)
+ check_positive_integer(n_windows, strict=True)
+ check_positive_integer(unroll_length, strict=True)
+
+ return BaseCondition.__new__(cls)
+
+ def store_data(self, **kwargs):
+ """
+ Store the unrolled time-series input data.
+
+ The method extracts the time-series input data and creates the temporal
+ windows based on the specified ``unroll_length`` and ``n_windows``.
+
+ :param dict kwargs: The keyword arguments containing the data to be
+ stored.
+ :return: A dictionary-like structure containing the stored data.
+ :rtype: _DataManager
+ """
+ # Extract unrolling parameters from kwargs
+ unroll_length = kwargs.get("unroll_length")
+ n_windows = kwargs.get("n_windows")
+ randomize = kwargs.get("randomize", False)
+ key = kwargs.get("key", "x")
+ graph = kwargs.get("input")
+
+ # Create unrolled windows from the input data
+ if not hasattr(graph, key):
+ raise ValueError(
+ f"The provided graph does not have the specified key '{key}'."
+ )
+
+ unrolled_data = self._unroll(
+ data=graph.__getattribute__(key),
+ n_windows=n_windows,
+ unroll_length=unroll_length,
+ randomize=randomize,
+ )
+ graph.__setattr__(key, unrolled_data)
+
+ return _DataManager(input=graph)
+
+ def evaluate(self, batch, solver):
+ """
+ Evaluate the residual of the condition on the given batch using the
+ solver.
+
+ This method computes the per-step residuals through autoregressive
+ unrolling. A forward pass of the solver's model is performed at each
+ time step, and the per-step residuals (predicted - target) are
+ returned as a stacked tensor.
+
+ The returned tensor preserves all per-step residual values without
+ reduction or loss aggregation.
+
+ :param dict batch: The batch containing the data required by the
+ condition evaluation.
+ :param SolverInterface solver: The solver used to perform the forward
+ pass and compute the residual. The solver provides access to the
+ model and its parameters, which may be necessary for evaluating the
+ condition residual.
+ :raises ValueError: If the input tensor in the batch has less than 4
+ dimensions.
+ :return: The stacked per-step residual tensor of shape
+ ``[time_steps - 1, trajectories, windows, *features]``.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Raise error if input tensor does not have at least 4 dimensions
+ if batch["input"].x.dim() < 4:
+ raise ValueError(
+ "The provided input tensor must have at least 4 dimensions:"
+ " [trajectories, windows, time_steps, *features]."
+ f" Got shape {batch['input'].shape}."
+ )
+
+ # Copy the kwargs to avoid modifying the original settings
+ kwargs = solver._kwargs.copy()
+
+ # Extract the initial state and initialize the step-wise residuals list
+ current_state = batch["input"].x[:, :, 0, :]
+ residuals = []
+
+ # Iterate over the time steps
+ for step in range(1, batch["input"].x.shape[2]):
+
+ # Pre-process, forward, and post-process the current state
+ processed_input = solver.preprocess_step(current_state, **kwargs)
+ output = solver.forward(processed_input)
+ predicted_state = solver.postprocess_step(output, **kwargs)
+
+ # Retrieve the target and compute the step-wise residual
+ target_state = batch["input"].x[:, :, step, :]
+ step_residual = predicted_state - target_state
+ residuals.append(step_residual)
+
+ # Update the current state for the next iteration
+ current_state = predicted_state
+
+ # Stack the step-wise residuals
+ return torch.stack(residuals).as_subclass(torch.Tensor)
diff --git a/pina/_src/condition/input_equation_condition.py b/pina/_src/condition/input_equation_condition.py
new file mode 100644
index 000000000..8682f7af7
--- /dev/null
+++ b/pina/_src/condition/input_equation_condition.py
@@ -0,0 +1,136 @@
+"""Module for the Input-Equation Condition class."""
+
+from pina._src.condition.base_condition import BaseCondition
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.graph import Graph
+from pina._src.equation.base_equation import BaseEquation
+from pina._src.data.manager.data_manager import _DataManager
+from pina._src.core.utils import check_consistency
+
+
+class InputEquationCondition(BaseCondition):
+ """
+ The class :class:`InputEquationCondition` defines a condition based on
+ ``input`` data and an ``equation``. This condition is typically used in
+ physics-informed problems, where the model is trained to satisfy a given
+ ``equation`` through the evaluation of the residual performed at the
+ provided ``input``.
+
+ :Example:
+
+ >>> from pina import Condition, LabelTensor
+ >>> from pina.equation import Equation
+ >>> import torch
+
+ >>> # Equation to be satisfied over the input points: # x^2 + y^2 - 1 = 0
+ >>> def dummy_equation(pts):
+ ... return pts["x"]**2 + pts["y"]**2 - 1
+
+ >>> pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
+ >>> condition = Condition(input=pts, equation=Equation(dummy_equation))
+ """
+
+ # Available fields, input and equation data types
+ __fields__ = ["input", "equation"]
+ _avail_input_cls = (LabelTensor, Graph)
+ _avail_equation_cls = BaseEquation
+
+ def __new__(cls, input, equation):
+ """
+ Check the types of ``input`` and ``equation`` and instantiate an
+ instance of :class:`InputEquationCondition` accordingly.
+
+ :param input: The input data associated with the condition.
+ :type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
+ :param BaseEquation equation: The equation associated with the
+ condition.
+ :raises ValueError: If ``input`` is not an instance of
+ :class:`~pina.label_tensor.LabelTensor`, or
+ :class:`~pina.graph.Graph`, nor a list or tuple of
+ :class:`~pina.graph.Graph`.
+ :raises ValueError: If ``equation`` is not an instance of
+ :class:`~pina.equation.base_equation.BaseEquation`.
+ :return: A new instance of :class:`InputEquationCondition`.
+ :rtype: InputEquationCondition
+ """
+ # Check input type - equation is checked in the setter
+ if isinstance(input, (list, tuple)):
+ check_consistency(input, Graph)
+ else:
+ check_consistency(input, cls._avail_input_cls)
+
+ return super().__new__(cls)
+
+ def store_data(self, **kwargs):
+ """
+ Store the input data in a dictionary-like structure.
+
+ :param dict kwargs: The keyword arguments containing the data to be
+ stored.
+ :return: A dictionary-like structure containing the stored data.
+ :rtype: _DataManager
+ """
+ # Save the equation as an attribute of the condition instance
+ setattr(self, "equation", kwargs.pop("equation"))
+
+ return _DataManager(**kwargs)
+
+ @property
+ def input(self):
+ """
+ The input data associated with the condition.
+
+ :return: The input data.
+ :rtype: LabelTensor | Graph | list[Graph] | tuple[Graph]
+ """
+ return self.data.input
+
+ @property
+ def equation(self):
+ """
+ The equation associated with the condition.
+
+ :return: The equation.
+ :rtype: BaseEquation
+ """
+ return self._equation
+
+ @equation.setter
+ def equation(self, value):
+ """
+ Set the equation associated with this condition.
+
+ :param BaseEquation value: The equation to associate with the condition.
+ :raises ValueError: If ``value`` is not an instance of
+ :class:`~pina.equation.base_equation.BaseEquation`.
+ """
+ # Check consistency
+ check_consistency(value, self._avail_equation_cls)
+ self._equation = value
+
+ def evaluate(self, batch, solver):
+ """
+ Evaluate the residual of the condition on the given batch using the
+ solver.
+
+ This method computes the non-aggregated, element-wise residual of the
+ condition. A forward pass of the solver's model is performed on the
+ input samples, and the condition residual is evaluated accordingly.
+
+ The returned tensor is not reduced, preserving the per-sample residual
+ values.
+
+ :param dict batch: The batch containing the data required by the
+ condition evaluation.
+ :param BaseSolver solver: The solver used to perform the forward pass
+ and compute the residual. The solver provides access to the model
+ and its parameters, which may be necessary for evaluating the
+ condition residual.
+ :return: The non-aggregated residual tensor.
+ :rtype: LabelTensor
+ """
+ # Compute residuals
+ samples = batch["input"].requires_grad_(True)
+ return self.equation.residual(
+ samples, solver.forward(samples), solver._params
+ )
diff --git a/pina/_src/condition/input_target_condition.py b/pina/_src/condition/input_target_condition.py
new file mode 100644
index 000000000..ead8cee3c
--- /dev/null
+++ b/pina/_src/condition/input_target_condition.py
@@ -0,0 +1,129 @@
+"""Module for the Input-Target Condition class."""
+
+import torch
+from torch_geometric.data import Data
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.graph import Graph
+from pina._src.condition.base_condition import BaseCondition
+from pina._src.data.manager.data_manager import _DataManager
+from pina._src.core.utils import check_consistency
+
+
+class InputTargetCondition(BaseCondition):
+ """
+ The :class:`InputTargetCondition` class represents a supervised condition
+ defined by both ``input`` and ``target`` data. The model is trained to
+ reproduce the ``target`` values given the ``input``. Supported data types
+ include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
+ :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
+
+ :Example:
+
+ >>> from pina import Condition, LabelTensor
+ >>> from pina.graph import Graph
+ >>> import torch
+
+ >>> pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
+ >>> edge_index = torch.randint(0, 100, (2, 300))
+ >>> graph = Graph(pos=pos, edge_index=edge_index)
+
+ >>> input = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
+ >>> condition = Condition(input=input, target=graph)
+ """
+
+ # Available fields, input, and target data types
+ __fields__ = ["input", "target"]
+ _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph)
+ _avail_target_cls = (torch.Tensor, LabelTensor, Data, Graph)
+
+ def __new__(cls, input, target):
+ """
+ Check the types of ``input`` and ``target`` data and instantiate an
+ instance of :class:`InputTargetCondition` accordingly.
+
+ :param input: The input data associated with the condition.
+ :type input: torch.Tensor | LabelTensor | Graph |
+ Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
+ :param target: The target data associated with the condition.
+ :type target: torch.Tensor | LabelTensor | Graph |
+ Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
+ :raises ValueError: If ``input`` is not of type :class:`torch.Tensor`,
+ :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
+ or :class:`~torch_geometric.data.Data`, nor is it a list or tuple of
+ :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`.
+ :raises ValueError: If ``target`` is not of type :class:`torch.Tensor`,
+ :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
+ or :class:`~torch_geometric.data.Data`, nor is it a list or tuple of
+ :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`.
+ :return: A new instance of :class:`InputTargetCondition`.
+ :rtype: InputTargetCondition
+ """
+ # Check input type - if iterable, ensure it is either Data or Graph
+ if isinstance(input, (list, tuple)):
+ check_consistency(input, (Data, Graph))
+ else:
+ check_consistency(input, cls._avail_input_cls)
+
+ # Check target type - if iterable, ensure it is either Data or Graph
+ if isinstance(target, (list, tuple)):
+ check_consistency(target, (Data, Graph))
+ else:
+ check_consistency(target, cls._avail_target_cls)
+
+ return super().__new__(cls)
+
+ def store_data(self, **kwargs):
+ """
+ Store the input and target data in a dictionary-like structure.
+
+ :param dict kwargs: The keyword arguments containing the data to be
+ stored.
+ :return: A dictionary-like structure containing the stored data.
+ :rtype: _DataManager
+ """
+ return _DataManager(**kwargs)
+
+ def evaluate(self, batch, solver):
+ """
+ Evaluate the residual of the condition on the given batch using the
+ solver.
+
+ This method computes the non-aggregated, element-wise residual of the
+ condition. A forward pass of the solver's model is performed on the
+ input samples, and the condition residual is evaluated accordingly.
+
+ The returned tensor is not reduced, preserving the per-sample residual
+ values.
+
+ :param dict batch: The batch containing the data required by the
+ condition evaluation.
+ :param BaseSolver solver: The solver used to perform the forward pass
+ and compute the residual. The solver provides access to the model
+ and its parameters, which may be necessary for evaluating the
+ condition residual.
+ :return: The non-aggregated residual tensor.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ return solver.forward(batch["input"]) - batch["target"]
+
+ @property
+ def input(self):
+ """
+ The input data associated with the condition.
+
+ :return: The input data.
+ :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
+ list[Data] | tuple[Graph] | tuple[Data]
+ """
+ return self.data.input
+
+ @property
+ def target(self):
+ """
+ The target data associated with the condition.
+
+ :return: The target data.
+ :rtype: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
+ list[Data] | tuple[Graph] | tuple[Data]
+ """
+ return self.data.target
diff --git a/pina/_src/condition/time_series_condition.py b/pina/_src/condition/time_series_condition.py
new file mode 100644
index 000000000..28b38eaa6
--- /dev/null
+++ b/pina/_src/condition/time_series_condition.py
@@ -0,0 +1,247 @@
+"""Module for the TimeSeriesCondition class."""
+
+import torch
+from pina._src.core.utils import check_consistency, check_positive_integer
+from pina._src.data.manager.data_manager import _DataManager
+from pina._src.condition.base_condition import BaseCondition
+from pina._src.core.label_tensor import LabelTensor
+
+
+class TimeSeriesCondition(BaseCondition):
+ """
+ The :class:`TimeSeriesCondition` class represents an autoregressive time
+ series condition defined by temporal ``input`` data. The input is expected
+ to have shape ``[trajectories, time_steps, *features]``, where the second
+ dimension corresponds to the temporal evolution of each trajectory.
+
+ During training, the condition automatically extracts overlapping temporal
+ windows from the trajectories. The parameter ``unroll_length`` defines the
+ number of consecutive time steps contained in each temporal window, while
+ ``n_windows`` controls how many temporal windows are created from the
+ available trajectories.
+
+ Internally, the unrolled data is stored as a tensor of shape
+ ``[trajectories, n_windows, unroll_length, *features]``.
+
+ Supported data types include :class:`~pina.label_tensor.LabelTensor` and
+ :class:`torch.Tensor`.
+
+ :Example:
+
+ >>> from pina import Condition, LabelTensor
+ >>> import torch
+
+ >>> data = LabelTensor(torch.rand(5, 10, 2), labels=["u", "v"])
+ >>> condition = Condition(input=data, unroll_length=5, n_windows=3)
+ """
+
+ # Available fields and input data types
+ __fields__ = ["input", "unroll_length", "n_windows", "randomize"]
+ _avail_input_cls = (torch.Tensor, LabelTensor)
+
+ def __new__(cls, input, n_windows, unroll_length, randomize=False):
+ """
+ Validate the input data and time-series parameters.
+
+ :param input: The temporal input data.
+ :type input: torch.Tensor | LabelTensor
+ :param int n_windows: The maximum number of temporal windows to extract.
+ :param int unroll_length: The number of time steps in each window.
+ :param bool randomize: If ``True``, randomly permute the valid starting
+ indices before selecting the windows. Default is ``False``.
+ :raises ValueError: If ``input`` is not of type :class:`torch.Tensor` or
+ :class:`~pina.label_tensor.LabelTensor`.
+ :raises AssertionError: If ``unroll_length`` is not a positive integer.
+ :raises AssertionError: If ``n_windows`` is not a positive integer.
+ :raises ValueError: If ``randomize`` is not a boolean value.
+ :raises ValueError: If ``input`` has fewer than three dimensions.
+ :raises ValueError: If ``unroll_length`` is lower than 2.
+ :return: A new :class:`TimeSeriesCondition` instance.
+ :rtype: TimeSeriesCondition
+ """
+ # Check consistency
+ check_consistency(input, cls._avail_input_cls)
+ check_consistency(randomize, bool)
+ check_positive_integer(n_windows, strict=True)
+ check_positive_integer(unroll_length, strict=True)
+
+ # Validate input
+ if input.dim() < 3:
+ raise ValueError(
+ "The provided data tensor must have at least 3 dimensions: "
+ f"[trajectories, time, *features]. Got shape {input.shape}."
+ )
+
+ # Validate unroll_length
+ if unroll_length < 2:
+ raise ValueError(
+ f"unroll_length must be strictly greater than 1 to create "
+ f" temporal windows. Got unroll_length={unroll_length}."
+ )
+
+ return super().__new__(cls)
+
+ def store_data(self, **kwargs):
+ """
+ Store the unrolled time-series input data.
+
+ The method extracts the time-series input data and creates the temporal
+ windows based on the specified ``unroll_length`` and ``n_windows``.
+
+ :param dict kwargs: The keyword arguments containing the data to be
+ stored.
+ :return: A dictionary-like structure containing the stored data.
+ :rtype: _DataManager
+ """
+ # Extract unrolling parameters from kwargs
+ unroll_length = kwargs.get("unroll_length")
+ n_windows = kwargs.get("n_windows")
+ randomize = kwargs.get("randomize", False)
+ data = kwargs.get("input")
+
+ # Create unrolled windows from the input data
+ unrolled_data = self._unroll(
+ data=data,
+ n_windows=n_windows,
+ unroll_length=unroll_length,
+ randomize=randomize,
+ )
+
+ # Preserve labels if the input data is a LabelTensor
+ if isinstance(data, LabelTensor):
+ unrolled_data = unrolled_data.as_subclass(LabelTensor)
+ unrolled_data.labels = data.labels
+
+ return _DataManager(input=unrolled_data)
+
+ def _unroll(self, data, n_windows, unroll_length, randomize):
+ """
+ Build temporal windows from time-series data.
+
+ Given data with shape ``[trajectories, time_steps, *features]``, this
+ method returns a tensor of overlapping temporal windows with shape
+ ``[trajectories, windows, unroll_length, *features]``.
+
+ :param data: The temporal data tensor to be unrolled.
+ :type data: torch.Tensor | LabelTensor
+ :param int n_windows: The maximum number of temporal windows to extract.
+ :param int unroll_length: The number of time steps in each window.
+ :param bool randomize: If ``True``, starting indices are randomly
+ permuted before applying ``n_windows``. Default is ``True``.
+ :raises ValueError: If ``unroll_length`` is greater than the number of
+ time steps in the data.
+ :return: A tensor of unrolled windows.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Store the number of time steps in the data
+ time_steps = data.shape[1]
+
+ # Compute the last valid starting index for unroll windows
+ last_idx = time_steps - unroll_length
+
+ # Raise error if unroll_length is greater than time_steps
+ if last_idx < 0:
+ raise ValueError(
+ f"Cannot create unroll windows: unroll_length {unroll_length} "
+ f"exceeds the available number of time steps {time_steps}."
+ )
+
+ # Extract starting indices
+ start_indices = torch.arange(last_idx + 1)
+
+ # Randomly permute starting indices if randomize is True
+ if randomize:
+ start_indices = start_indices[torch.randperm(len(start_indices))]
+
+ # Raise error if n_windows is greater than the number of valid windows
+ if len(start_indices) < n_windows:
+ raise ValueError(
+ f"Cannot create {n_windows} unroll windows with the selected "
+ f"unroll_length {unroll_length} from data with {time_steps} "
+ f"time steps. Only {len(start_indices)} valid windows are "
+ "available."
+ )
+
+ # Limit the number of windows to n_windows
+ start_indices = start_indices[:n_windows]
+
+ # Create unroll windows by slicing the input data at the starting idx
+ windows = [data[:, s : s + unroll_length] for s in start_indices]
+
+ if isinstance(data, LabelTensor):
+ # Preserve labels if the input data is a LabelTensor
+ unrolled_data = torch.stack(windows, dim=1).as_subclass(LabelTensor)
+ unrolled_data.labels = data.labels
+ else:
+ unrolled_data = torch.stack(windows, dim=1)
+
+ return unrolled_data
+
+ def evaluate(self, batch, solver):
+ """
+ Evaluate the residual of the condition on the given batch using the
+ solver.
+
+ This method computes the per-step residuals through autoregressive
+ unrolling. A forward pass of the solver's model is performed at each
+ time step, and the per-step residuals (predicted - target) are
+ returned as a stacked tensor.
+
+ The returned tensor preserves all per-step residual values without
+ reduction or loss aggregation.
+
+ :param dict batch: The batch containing the data required by the
+ condition evaluation.
+ :param BaseSolver solver: The solver used to perform the forward pass
+ and compute the residual. The solver provides access to the model
+ and its parameters, which may be necessary for evaluating the
+ condition residual.
+ :raises ValueError: If the input tensor in the batch has less than 4
+ dimensions.
+ :return: The stacked per-step residual tensor of shape
+ ``[time_steps - 1, trajectories, windows, *features]``.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Raise error if input tensor does not have at least 4 dimensions
+ if batch["input"].dim() < 4:
+ raise ValueError(
+ "The provided input tensor must have at least 4 dimensions:"
+ " [trajectories, windows, time_steps, *features]."
+ f" Got shape {batch['input'].shape}."
+ )
+
+ # Copy the kwargs to avoid modifying the original settings
+ kwargs = solver._kwargs.copy()
+
+ # Extract the initial state and initialize the step-wise residuals list
+ current_state = batch["input"][:, :, 0]
+ residuals = []
+
+ # Iterate over the time steps
+ for step in range(1, batch["input"].shape[2]):
+
+ # Pre-process, forward, and post-process the current state
+ processed_input = solver.preprocess_step(current_state, **kwargs)
+ output = solver.forward(processed_input)
+ predicted_state = solver.postprocess_step(output, **kwargs)
+
+ # Retrieve the target and compute the step-wise residual
+ target_state = batch["input"][:, :, step]
+ step_residual = predicted_state - target_state
+ residuals.append(step_residual)
+
+ # Update the current state for the next iteration
+ current_state = predicted_state
+
+ # Stack the step-wise residuals
+ return torch.stack(residuals).as_subclass(torch.Tensor)
+
+ @property
+ def input(self):
+ """
+ The unrolled temporal input data.
+
+ :return: The input data.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ return self.data.input
diff --git a/pina/_src/core/__init__.py b/pina/_src/core/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/core/graph.py b/pina/_src/core/graph.py
new file mode 100644
index 000000000..4b0a2fcb0
--- /dev/null
+++ b/pina/_src/core/graph.py
@@ -0,0 +1,421 @@
+"""Module to build Graph objects and perform operations on them."""
+
+import torch
+from torch_geometric.data import Data, Batch
+from torch_geometric.utils import to_undirected
+from torch_geometric.utils.loop import remove_self_loops
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency, is_function
+
+
+class Graph(Data):
+ """
+ Extends :class:`~torch_geometric.data.Data` class to include additional
+ checks and functionalities.
+ """
+
+ def __new__(
+ cls,
+ **kwargs,
+ ):
+ """
+ Create a new instance of the :class:`~pina.graph.Graph` class by
+ checking the consistency of the input data and storing the attributes.
+
+ :param dict kwargs: Parameters used to initialize the
+ :class:`~pina.graph.Graph` object.
+ :return: A new instance of the :class:`~pina.graph.Graph` class.
+ :rtype: Graph
+ """
+ # create class instance
+ instance = Data.__new__(cls)
+
+ # check the consistency of types defined in __init__, the others are not
+ # checked (as in pyg Data object)
+ instance._check_type_consistency(**kwargs)
+
+ return instance
+
+ def __init__(
+ self,
+ x=None,
+ edge_index=None,
+ pos=None,
+ edge_attr=None,
+ undirected=False,
+ **kwargs,
+ ):
+ """
+ Initialize the object by setting the node features, edge index,
+ edge attributes, and positions. The edge index is preprocessed to make
+ the graph undirected if required. For more details, see the
+ :meth:`torch_geometric.data.Data`
+
+ :param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
+ number of features per node.
+ :type x: torch.Tensor, LabelTensor
+ :param torch.Tensor edge_index: A tensor of shape ``(2, E)``
+ representing the indices of the graph's edges.
+ :param pos: A tensor of shape ``(N, D)`` representing the positions of
+ ``N`` points in ``D``-dimensional space.
+ :type pos: torch.Tensor | LabelTensor
+ :param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
+ ``F'`` is the number of edge features
+ :type edge_attr: torch.Tensor | LabelTensor
+ :param bool undirected: Whether to make the graph undirected
+ :param dict kwargs: Additional keyword arguments passed to the
+ :class:`~torch_geometric.data.Data` class constructor.
+ """
+ # preprocessing
+ self._preprocess_edge_index(edge_index, undirected)
+
+ # calling init
+ super().__init__(
+ x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, **kwargs
+ )
+
+ def _check_type_consistency(self, **kwargs):
+ """
+ Check the consistency of the types of the input data.
+
+ :param dict kwargs: Attributes to be checked for consistency.
+ """
+ # default types, specified in cls.__new__, by default they are Nont
+ # if specified in **kwargs they get override
+ x, pos, edge_index, edge_attr = None, None, None, None
+ if "pos" in kwargs:
+ pos = kwargs["pos"]
+ self._check_pos_consistency(pos)
+ if "edge_index" in kwargs:
+ edge_index = kwargs["edge_index"]
+ self._check_edge_index_consistency(edge_index)
+ if "x" in kwargs:
+ x = kwargs["x"]
+ # self._check_x_consistency(x, pos)
+ if "edge_attr" in kwargs:
+ edge_attr = kwargs["edge_attr"]
+ self._check_edge_attr_consistency(edge_attr, edge_index)
+ if "undirected" in kwargs:
+ undirected = kwargs["undirected"]
+ check_consistency(undirected, bool)
+
+ @staticmethod
+ def _check_pos_consistency(pos):
+ """
+ Check if the position tensor is consistent.
+ :param torch.Tensor pos: The position tensor.
+ :raises ValueError: If the position tensor is not consistent.
+ """
+ if pos is not None:
+ check_consistency(pos, (torch.Tensor, LabelTensor))
+ if pos.ndim != 2:
+ raise ValueError("pos must be a 2D tensor.")
+
+ @staticmethod
+ def _check_edge_index_consistency(edge_index):
+ """
+ Check if the edge index is consistent.
+
+ :param torch.Tensor edge_index: The edge index tensor.
+ :raises ValueError: If the edge index tensor is not consistent.
+ """
+ check_consistency(edge_index, (torch.Tensor, LabelTensor))
+ if edge_index.ndim != 2:
+ raise ValueError("edge_index must be a 2D tensor.")
+ if edge_index.size(0) != 2:
+ raise ValueError("edge_index must have shape [2, num_edges].")
+
+ @staticmethod
+ def _check_edge_attr_consistency(edge_attr, edge_index):
+ """
+ Check if the edge attribute tensor is consistent in type and shape
+ with the edge index.
+
+ :param edge_attr: The edge attribute tensor.
+ :type edge_attr: torch.Tensor | LabelTensor
+ :param torch.Tensor edge_index: The edge index tensor.
+ :raises ValueError: If the edge attribute tensor is not consistent.
+ """
+ if edge_attr is not None:
+ check_consistency(edge_attr, (torch.Tensor, LabelTensor))
+ if edge_attr.ndim != 2:
+ raise ValueError("edge_attr must be a 2D tensor.")
+ if edge_attr.size(0) != edge_index.size(1):
+ raise ValueError(
+ "edge_attr must have shape "
+ "[num_edges, num_edge_features], expected "
+ f"num_edges {edge_index.size(1)} "
+ f"got {edge_attr.size(0)}."
+ )
+
+ @staticmethod
+ def _check_x_consistency(x, pos=None):
+ """
+ Check if the input tensor x is consistent with the position tensor
+ `pos`.
+
+ :param x: The input tensor.
+ :type x: torch.Tensor | LabelTensor
+ :param pos: The position tensor.
+ :type pos: torch.Tensor | LabelTensor
+ :raises ValueError: If the input tensor is not consistent.
+ """
+ if x is not None:
+ check_consistency(x, (torch.Tensor, LabelTensor))
+ if x.ndim != 2:
+ raise ValueError("x must be a 2D tensor.")
+ if pos is not None:
+ if x.size(0) != pos.size(0):
+ raise ValueError("Inconsistent number of nodes.")
+
+ @staticmethod
+ def _preprocess_edge_index(edge_index, undirected):
+ """
+ Preprocess the edge index to make the graph undirected (if required).
+
+ :param torch.Tensor edge_index: The edge index.
+ :param bool undirected: Whether the graph is undirected.
+ :return: The preprocessed edge index.
+ :rtype: torch.Tensor
+ """
+ if undirected:
+ edge_index = to_undirected(edge_index)
+ return edge_index
+
+ def extract(self, labels, attr="x"):
+ """
+ Perform extraction of labels from the attribute specified by `attr`.
+
+ :param labels: Labels to extract
+ :type labels: list[str] | tuple[str] | str | dict
+ :return: Batch object with extraction performed on x
+ :rtype: PinaBatch
+ """
+ # Extract labels from LabelTensor object
+ tensor = getattr(self, attr).extract(labels)
+ # Set the extracted tensor as the new attribute
+ setattr(self, attr, tensor)
+ return self
+
+
+class GraphBuilder:
+ """
+ A class that allows an easy definition of :class:`Graph` instances.
+ """
+
+ def __new__(
+ cls,
+ pos,
+ edge_index,
+ x=None,
+ edge_attr=False,
+ custom_edge_func=None,
+ loop=True,
+ **kwargs,
+ ):
+ """
+ Compute the edge attributes and create a new instance of the
+ :class:`~pina.graph.Graph` class.
+
+ :param pos: A tensor of shape ``(N, D)`` representing the positions of
+ ``N`` points in ``D``-dimensional space.
+ :type pos: torch.Tensor or LabelTensor
+ :param edge_index: A tensor of shape ``(2, E)`` representing the indices
+ of the graph's edges.
+ :type edge_index: torch.Tensor
+ :param x: Optional tensor of node features of shape ``(N, F)``, where
+ ``F`` is the number of features per node.
+ :type x: torch.Tensor | LabelTensor, optional
+ :param bool edge_attr: Whether to compute the edge attributes.
+ :param custom_edge_func: A custom function to compute edge attributes.
+ If provided, overrides ``edge_attr``.
+ :type custom_edge_func: Callable, optional
+ :param bool loop: Whether to include self-loops.
+ :param kwargs: Additional keyword arguments passed to the
+ :class:`~pina.graph.Graph` class constructor.
+ :return: A :class:`~pina.graph.Graph` instance constructed using the
+ provided information.
+ :rtype: Graph
+ """
+ if not loop:
+ edge_index = remove_self_loops(edge_index)[0]
+ edge_attr = cls._create_edge_attr(
+ pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr
+ )
+ return Graph(
+ x=x,
+ edge_index=edge_index,
+ edge_attr=edge_attr,
+ pos=pos,
+ **kwargs,
+ )
+
+ @staticmethod
+ def _create_edge_attr(pos, edge_index, edge_attr, func):
+ """
+ Create the edge attributes based on the input parameters.
+
+ :param pos: Positions of the points.
+ :type pos: torch.Tensor | LabelTensor
+ :param torch.Tensor edge_index: Edge indices.
+ :param bool edge_attr: Whether to compute the edge attributes.
+ :param Callable func: Function to compute the edge attributes.
+ :raises ValueError: If ``func`` is not a function.
+ :return: The edge attributes.
+ :rtype: torch.Tensor | LabelTensor | None
+ """
+ check_consistency(edge_attr, bool)
+ if edge_attr:
+ if is_function(func):
+ return func(pos, edge_index)
+ raise ValueError("custom_edge_func must be a function.")
+ return None
+
+ @staticmethod
+ def _build_edge_attr(pos, edge_index):
+ """
+ Default function to compute the edge attributes.
+
+ :param pos: Positions of the points.
+ :type pos: torch.Tensor | LabelTensor
+ :param torch.Tensor edge_index: Edge indices.
+ :return: The edge attributes.
+ :rtype: torch.Tensor
+ """
+ return (
+ (pos[edge_index[0]] - pos[edge_index[1]])
+ .abs()
+ .as_subclass(torch.Tensor)
+ )
+
+
+class RadiusGraph(GraphBuilder):
+ """
+ Extends the :class:`~pina.graph.GraphBuilder` class to compute
+ ``edge_index`` based on a radius. Each point is connected to all the points
+ within the radius.
+ """
+
+ def __new__(cls, pos, radius, **kwargs):
+ """
+ Instantiate the :class:`~pina.graph.Graph` class by computing the
+ ``edge_index`` based on the radius provided.
+
+ :param pos: A tensor of shape ``(N, D)`` representing the positions of
+ ``N`` points in ``D``-dimensional space.
+ :type pos: torch.Tensor | LabelTensor
+ :param float radius: The radius within which points are connected.
+ :param dict kwargs: The additional keyword arguments to be passed to
+ :class:`GraphBuilder` and :class:`Graph` classes.
+ :return: A :class:`~pina.graph.Graph` instance with the computed
+ ``edge_index``.
+ :rtype: Graph
+ """
+ edge_index = cls.compute_radius_graph(pos, radius)
+ return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
+
+ @staticmethod
+ def compute_radius_graph(points, radius):
+ """
+ Computes the ``edge_index`` based on the radius. Each point is connected
+ to all the points within the radius.
+
+ :param points: A tensor of shape ``(N, D)`` representing the positions
+ of ``N`` points in ``D``-dimensional space.
+ :type points: torch.Tensor | LabelTensor
+ :param float radius: The radius within which points are connected.
+ :return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
+ representing the edge indices of the graph.
+ :rtype: torch.Tensor
+ """
+ dist = torch.cdist(points, points, p=2)
+ return (
+ torch.nonzero(dist <= radius, as_tuple=False)
+ .t()
+ .as_subclass(torch.Tensor)
+ )
+
+
+class KNNGraph(GraphBuilder):
+ """
+ Extends the :class:`~pina.graph.GraphBuilder` class to compute
+ ``edge_index`` based on a K-nearest neighbors algorithm.
+ """
+
+ def __new__(cls, pos, neighbours, **kwargs):
+ """
+ Instantiate the :class:`~pina.graph.Graph` class by computing the
+ ``edge_index`` based on the K-nearest neighbors algorithm.
+
+ :param pos: A tensor of shape ``(N, D)`` representing the positions of
+ ``N`` points in ``D``-dimensional space.
+ :type pos: torch.Tensor | LabelTensor
+ :param int neighbours: The number of nearest neighbors to consider when
+ building the graph.
+ :param dict kwargs: The additional keyword arguments to be passed to
+ :class:`GraphBuilder` and :class:`Graph` classes.
+
+ :return: A :class:`~pina.graph.Graph` instance with the computed
+ ``edge_index``.
+ :rtype: Graph
+ """
+
+ edge_index = cls.compute_knn_graph(pos, neighbours)
+ return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
+
+ @staticmethod
+ def compute_knn_graph(points, neighbours):
+ """
+ Computes the ``edge_index`` based on the K-nearest neighbors algorithm.
+
+ :param points: A tensor of shape ``(N, D)`` representing the positions
+ of ``N`` points in ``D``-dimensional space.
+ :type points: torch.Tensor | LabelTensor
+ :param int neighbours: The number of nearest neighbors to consider when
+ building the graph.
+ :return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
+ representing the edge indices of the graph.
+ :rtype: torch.Tensor
+ """
+ dist = torch.cdist(points, points, p=2)
+ knn_indices = torch.topk(dist, k=neighbours, largest=False).indices
+ row = torch.arange(points.size(0)).repeat_interleave(neighbours)
+ col = knn_indices.flatten()
+ return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
+
+
+class LabelBatch(Batch):
+ """
+ Extends the :class:`~torch_geometric.data.Batch` class to include
+ :class:`~pina.label_tensor.LabelTensor` objects.
+ """
+
+ @classmethod
+ def from_data_list(cls, data_list):
+ """
+ Create a Batch object from a list of :class:`~torch_geometric.data.Data`
+ or :class:`~pina.graph.Graph` objects.
+
+ :param data_list: List of :class:`~torch_geometric.data.Data` or
+ :class:`~pina.graph.Graph` objects.
+ :type data_list: list[Data] | list[Graph]
+ :return: A :class:`~torch_geometric.data.Batch` object containing
+ the input data.
+ :rtype: :class:`~torch_geometric.data.Batch`
+ """
+ # Store the labels of Data/Graph objects (all data have the same labels)
+ # If the data do not contain labels, labels is an empty dictionary,
+ # therefore the labels are not stored
+ labels = {
+ k: v.labels
+ for k, v in data_list[0].items()
+ if isinstance(v, LabelTensor)
+ }
+
+ # Create a Batch object from the list of Data objects
+ batch = super().from_data_list(data_list)
+
+ # Put the labels back in the Batch object
+ for k, v in labels.items():
+ batch[k].labels = v
+ return batch
diff --git a/pina/_src/core/label_tensor.py b/pina/_src/core/label_tensor.py
new file mode 100644
index 000000000..41bccc6fc
--- /dev/null
+++ b/pina/_src/core/label_tensor.py
@@ -0,0 +1,753 @@
+"""Module for LabelTensor"""
+
+from copy import copy, deepcopy
+import torch
+from torch import Tensor
+
+
+class LabelTensor(torch.Tensor):
+ """
+ Extension of the :class:`torch.Tensor` class that includes labels for
+ each dimension.
+ """
+
+ @staticmethod
+ def __new__(cls, x, labels, *args, **kwargs):
+ """
+ Create a new instance of the :class:`~pina.label_tensor.LabelTensor`
+ class.
+
+ :param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
+ :class:`~pina.label_tensor.LabelTensor`.
+ :param labels: Labels to assign to the tensor.
+ :type labels: str | list[str] | dict
+ :return: The instance of the :class:`~pina.label_tensor.LabelTensor`
+ class.
+ :rtype: LabelTensor
+ """
+
+ if isinstance(x, LabelTensor):
+ return x
+ return super().__new__(cls, x, *args, **kwargs)
+
+ @property
+ def tensor(self):
+ """
+ Returns the tensor part of the :class:`~pina.label_tensor.LabelTensor`
+ object.
+
+ :return: Tensor part of the :class:`~pina.label_tensor.LabelTensor`.
+ :rtype: torch.Tensor
+ """
+
+ return self.as_subclass(Tensor)
+
+ def __init__(self, x, labels):
+ """
+ Initialize the :class:`~pina.label_tensor.LabelTensor` instance, by
+ checking the consistency of the labels and the tensor. Specifically, the
+ labels must match the following conditions:
+
+ - At each dimension, the number of labels must match the size of the \
+ dimension.
+ - At each dimension, the labels must be unique.
+
+ The labels can be passed in the following formats:
+
+ :Example:
+ >>> from pina import LabelTensor
+ >>> tensor = LabelTensor(
+ >>> torch.rand((2000, 3)),
+ ... {1: {"name": "space", "dof": ['a', 'b', 'c']}})
+ >>> tensor = LabelTensor(
+ >>> torch.rand((2000, 3)),
+ ... ["a", "b", "c"])
+
+ The keys of the dictionary are the dimension indices, and the values are
+ dictionaries containing the labels and the name of the dimension. If
+ the labels are passed as a list, these are assigned to the last
+ dimension.
+
+ :param torch.Tensor x: The tensor to be casted as a
+ :class:`~pina.label_tensor.LabelTensor`.
+ :param labels: Labels to assign to the tensor.
+ :type labels: str | list[str] | dict
+ :raises ValueError: If the labels are not consistent with the tensor.
+ """
+ super().__init__()
+ if labels is not None:
+ self.labels = labels
+ else:
+ self._labels = {}
+
+ @property
+ def full_labels(self):
+ """
+ Returns the full labels of the tensor, even for the dimensions that are
+ not labeled.
+
+ :return: The full labels of the tensor
+ :rtype: dict
+ """
+ to_return_dict = {}
+ shape_tensor = self.shape
+ for i, value in enumerate(shape_tensor):
+ if i in self._labels:
+ to_return_dict[i] = self._labels[i]
+ else:
+ to_return_dict[i] = {"dof": range(value), "name": i}
+ return to_return_dict
+
+ @property
+ def stored_labels(self):
+ """
+ Returns the labels stored inside the instance.
+
+ :return: The labels stored inside the instance.
+ :rtype: dict
+ """
+ return self._labels
+
+ @property
+ def labels(self):
+ """
+ Returns the labels of the last dimension of the instance.
+
+ :return: labels of last dimension
+ :rtype: list
+ """
+ if self.ndim - 1 in self._labels:
+ return self._labels[self.ndim - 1]["dof"]
+ return None
+
+ @labels.setter
+ def labels(self, labels):
+ """
+ Set labels stored insider the instance by checking the type of the
+ input labels and handling it accordingly. The following types are
+ accepted:
+
+ - **list**: The list of labels is assigned to the last dimension.
+ - **dict**: The dictionary of labels is assigned to the tensor.
+ - **str**: The string is assigned to the last dimension.
+
+ :param labels: Labels to assign to the class variable _labels.
+ :type labels: str | list[str] | dict
+ """
+
+ if not hasattr(self, "_labels"):
+ self._labels = {}
+ if isinstance(labels, dict):
+ self._init_labels_from_dict(labels)
+ elif isinstance(labels, (list, range)):
+ self._init_labels_from_list(labels)
+ elif isinstance(labels, str):
+ labels = [labels]
+ self._init_labels_from_list(labels)
+ else:
+ raise ValueError("labels must be list, dict or string.")
+
+ def _init_labels_from_dict(self, labels):
+ """
+ Store the internal label representation according to the values
+ passed as input.
+
+ :param dict labels: The label(s) to update.
+ :raises ValueError: If the dof list contains duplicates or the number of
+ dof does not match the tensor shape.
+ """
+
+ tensor_shape = self.shape
+
+ def validate_dof(dof_list, dim_size):
+ """Validate the 'dof' list for uniqueness and size."""
+ if len(dof_list) != len(set(dof_list)):
+ raise ValueError("dof must be unique")
+ if len(dof_list) != dim_size:
+ raise ValueError(
+ f"Number of dof ({len(dof_list)}) does not match "
+ f"tensor shape ({dim_size})"
+ )
+
+ for dim, label in labels.items():
+ if isinstance(label, dict):
+ if "name" not in label:
+ label["name"] = dim
+ if "dof" not in label:
+ label["dof"] = range(tensor_shape[dim])
+ if "dof" in label and "name" in label:
+ dof = label["dof"]
+ dof_list = dof if isinstance(dof, (list, range)) else [dof]
+ if not isinstance(dof_list, (list, range)):
+ raise ValueError(
+ f"'dof' should be a list or range, not"
+ f" {type(dof_list)}"
+ )
+ validate_dof(dof_list, tensor_shape[dim])
+ else:
+ raise ValueError(
+ "Labels dictionary must contain either "
+ " both 'name' and 'dof' keys"
+ )
+ else:
+ raise ValueError(
+ f"Invalid label format for {dim}: Expected "
+ f"list or dictionary, got {type(label)}"
+ )
+
+ # Assign validated label data to internal labels
+ self._labels[dim] = label
+
+ def _init_labels_from_list(self, labels):
+ """
+ Given a list of dof, this method update the internal label
+ representation by assigning the dof to the last dimension.
+
+ :param labels: The label(s) to update.
+ :type labels: list
+ """
+
+ # Create a dict with labels
+ last_dim_labels = {
+ self.ndim - 1: {"dof": labels, "name": self.ndim - 1}
+ }
+ self._init_labels_from_dict(last_dim_labels)
+
+ def extract(self, labels_to_extract):
+ """
+ Extract the subset of the original tensor by returning all the positions
+ corresponding to the passed ``label_to_extract``. If
+ ``label_to_extract`` is a dictionary, the keys are the dimension names
+ and the values are the labels to extract. If a single label or a list
+ of labels is passed, the last dimension is considered.
+
+ :Example:
+ >>> from pina import LabelTensor
+ >>> labels = {1: {'dof': ["a", "b", "c"], 'name': 'space'}}
+ >>> tensor = LabelTensor(torch.rand((2000, 3)), labels)
+ >>> tensor.extract("a")
+ >>> tensor.extract(["a", "b"])
+ >>> tensor.extract({"space": ["a", "b"]})
+
+ :param labels_to_extract: The label(s) to extract.
+ :type labels_to_extract: str | list[str] | tuple[str] | dict
+ :return: The extracted tensor with the updated labels.
+ :rtype: LabelTensor
+
+ :raises TypeError: Labels are not ``str``, ``list[str]`` or ``dict``
+ properly setted.
+ :raises ValueError: Label to extract is not in the labels ``list``.
+ """
+
+ def get_label_indices(dim_labels, labels_te):
+ if isinstance(labels_te, (int, str)):
+ labels_te = [labels_te]
+ return (
+ [dim_labels.index(label) for label in labels_te]
+ if len(labels_te) > 1
+ else slice(
+ dim_labels.index(labels_te[0]),
+ dim_labels.index(labels_te[0]) + 1,
+ )
+ )
+
+ # Ensure labels_to_extract is a list or dict
+ if isinstance(labels_to_extract, (str, int)):
+ labels_to_extract = [labels_to_extract]
+
+ labels = copy(self._labels)
+
+ # Get the dimension names and the respective dimension index
+ dim_names = {labels[dim]["name"]: dim for dim in labels}
+ ndim = super().ndim
+ tensor = self.tensor.as_subclass(torch.Tensor)
+
+ # Convert list/tuple to a dict for the last dimension if applicable
+ if isinstance(labels_to_extract, (list, tuple)):
+ last_dim = ndim - 1
+ dim_name = labels[last_dim]["name"]
+ labels_to_extract = {dim_name: list(labels_to_extract)}
+
+ # Validate the labels_to_extract type
+ if not isinstance(labels_to_extract, dict):
+ raise ValueError(
+ "labels_to_extract must be a string, list, or dictionary."
+ )
+
+ # Perform the extraction for each specified dimension
+ for dim_name, labels_te in labels_to_extract.items():
+ if dim_name not in dim_names:
+ raise ValueError(
+ f"Cannot extract labels for dimension '{dim_name}' as it is"
+ f" not present in the original labels."
+ )
+
+ idx_dim = dim_names[dim_name]
+ dim_labels = labels[idx_dim]["dof"]
+ indices = get_label_indices(dim_labels, labels_te)
+
+ extractor = [slice(None)] * ndim
+ extractor[idx_dim] = indices
+ tensor = tensor[tuple(extractor)]
+
+ labels[idx_dim] = {"dof": labels_te, "name": dim_name}
+
+ return LabelTensor(tensor, labels)
+
+ def __str__(self):
+ """
+ The string representation of the
+ :class:`~pina.label_tensor.LabelTensor`.
+
+ :return: String representation of the
+ :class:`~pina.label_tensor.LabelTensor` instance.
+ :rtype: str
+ """
+
+ s = ""
+ for key, value in self._labels.items():
+ s += f"{key}: {value}\n"
+ s += "\n"
+ s += self.tensor.__str__()
+ return s
+
+ @staticmethod
+ def cat(tensors, dim=0):
+ """
+ Concatenate a list of tensors along a specified dimension. For more
+ details, see :meth:`torch.cat`.
+
+ :param list[LabelTensor] tensors:
+ :class:`~pina.label_tensor.LabelTensor` instances to concatenate
+ :param int dim: Dimensions on which you want to perform the operation
+ (default is 0)
+ :return: A new :class:`LabelTensor` instance obtained by concatenating
+ the input instances.
+
+ :rtype: LabelTensor
+ :raises ValueError: either number dof or dimensions names differ.
+ """
+
+ if not tensors:
+ return [] # Handle empty list
+ if len(tensors) == 1:
+ return tensors[0] # Return single tensor as-is
+
+ # Perform concatenation
+ cat_tensor = torch.cat(tensors, dim=dim)
+ tensors_labels = [tensor.stored_labels for tensor in tensors]
+
+ # Check label consistency across tensors, excluding the
+ # concatenation dimension
+ for key in tensors_labels[0]:
+ if key != dim:
+ if any(
+ tensors_labels[i][key] != tensors_labels[0][key]
+ for i in range(len(tensors_labels))
+ ):
+ raise RuntimeError(
+ f"Tensors must have the same labels along all "
+ f"dimensions except {dim}."
+ )
+
+ # Copy and update the 'dof' for the concatenation dimension
+ cat_labels = {k: copy(v) for k, v in tensors_labels[0].items()}
+
+ # Update labels if the concatenation dimension has labels
+ if dim in tensors[0].stored_labels:
+ if dim in cat_labels:
+ cat_dofs = [label[dim]["dof"] for label in tensors_labels]
+ cat_labels[dim]["dof"] = sum(cat_dofs, [])
+ else:
+ cat_labels = tensors[0].stored_labels
+
+ # Assign updated labels to the concatenated tensor
+ cat_tensor._labels = cat_labels
+ return cat_tensor
+
+ @staticmethod
+ def stack(tensors):
+ """
+ Stacks a list of tensors along a new dimension. For more details, see
+ :meth:`torch.stack`.
+
+ :param list[LabelTensor] tensors: A list of tensors to stack.
+ All tensors must have the same shape.
+ :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
+ by stacking the input tensors.
+ :rtype: LabelTensor
+ """
+
+ # Perform stacking in torch
+ new_tensor = torch.stack(tensors)
+
+ # Increase labels keys by 1
+ labels = tensors[0]._labels
+ labels = {key + 1: value for key, value in labels.items()}
+ new_tensor._labels = labels
+ return new_tensor
+
+ def requires_grad_(self, mode=True):
+ """
+ Override the :meth:`~torch.Tensor.requires_grad_` method to handle
+ the labels in the new tensor.
+ For more details, see :meth:`~torch.Tensor.requires_grad_`.
+
+ :param bool mode: A boolean value indicating whether the tensor should
+ track gradients.If `True`, the tensor will track gradients;
+ if `False`, it will not.
+ :return: The :class:`~pina.label_tensor.LabelTensor` itself with the
+ updated ``requires_grad`` state and retained labels.
+ :rtype: LabelTensor
+ """
+
+ lt = super().requires_grad_(mode)
+ lt._labels = self._labels
+ return lt
+
+ @property
+ def dtype(self):
+ """
+ Give the ``dtype`` of the tensor. For more details, see
+ :meth:`torch.dtype`.
+
+ :return: The data type of the tensor.
+ :rtype: torch.dtype
+ """
+
+ return super().dtype
+
+ def to(self, *args, **kwargs):
+ """
+ Performs Tensor dtype and/or device conversion. For more details, see
+ :meth:`torch.Tensor.to`.
+
+ :return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
+ updated dtype and/or device and retained labels.
+ :rtype: LabelTensor
+ """
+
+ lt = super().to(*args, **kwargs)
+ lt._labels = self._labels
+ return lt
+
+ def clone(self, *args, **kwargs):
+ """
+ Clone the :class:`~pina.label_tensor.LabelTensor`. For more details, see
+ :meth:`torch.Tensor.clone`.
+
+ :return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
+ same data and labels but allocated in a different memory location.
+ :rtype: LabelTensor
+ """
+
+ out = LabelTensor(
+ super().clone(*args, **kwargs), deepcopy(self._labels)
+ )
+ return out
+
+ def append(self, tensor, mode="std"):
+ """
+ Appends a given tensor to the current tensor along the last dimension.
+ This method supports two types of appending operations:
+
+ 1. **Standard append** ("std"): Concatenates the input tensor with the \
+ current tensor along the last dimension.
+ 2. **Cross append** ("cross"): Creates a cross-product of the current \
+ tensor and the input tensor.
+
+ :param tensor: The tensor to append to the current tensor.
+ :type tensor: LabelTensor
+ :param mode: The append mode to use. Defaults to ``st``.
+ :type mode: str, optional
+ :return: A new :class:`LabelTensor` instance obtained by appending the
+ input tensor.
+ :rtype: LabelTensor
+
+ :raises ValueError: If the mode is not "std" or "cross".
+ """
+
+ if mode == "std":
+ # Call cat on last dimension
+ new_label_tensor = LabelTensor.cat(
+ [self, tensor], dim=self.ndim - 1
+ )
+ return new_label_tensor
+ if mode == "cross":
+ # Crete tensor and call cat on last dimension
+ tensor1 = self
+ tensor2 = tensor
+ n1 = tensor1.shape[0]
+ n2 = tensor2.shape[0]
+ tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
+ tensor2 = LabelTensor(
+ tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
+ )
+ new_label_tensor = LabelTensor.cat(
+ [tensor1, tensor2], dim=self.ndim - 1
+ )
+ return new_label_tensor
+ raise ValueError('mode must be either "std" or "cross"')
+
+ @staticmethod
+ def vstack(tensors):
+ """
+ Stack tensors vertically. For more details, see :meth:`torch.vstack`.
+
+ :param list of LabelTensor label_tensors: The
+ :class:`~pina.label_tensor.LabelTensor` instances to stack. They
+ need to have equal labels.
+ :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
+ by stacking the input tensors vertically.
+ :rtype: LabelTensor
+ """
+
+ return LabelTensor.cat(tensors, dim=0)
+
+ # This method is used to update labels
+ def _update_single_label(self, index, dim):
+ """
+ Update the labels of the tensor based on the index (or list of indices).
+
+ :param index: Index of dof to retain.
+ :type index: int | slice | list[int] | tuple[int] | torch.Tensor
+ :param int dim: Dimension of the indexes in the original tensor.
+ :return: The updated labels for the specified dimension.
+ :rtype: list[int]
+ :raises: ValueError: If the index type is not supported.
+ """
+ old_dof = self._labels[dim]["dof"]
+ # Handle slicing
+ if isinstance(index, slice):
+ new_dof = old_dof[index]
+ # Handle single integer index
+ elif isinstance(index, int):
+ new_dof = [old_dof[index]]
+ # Handle lists or tensors
+ elif isinstance(index, (list, torch.Tensor)):
+ # Handle list of bools
+ if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
+ index = index.nonzero().squeeze()
+ new_dof = (
+ [old_dof[i] for i in index]
+ if isinstance(old_dof, list)
+ else index
+ )
+ else:
+ raise NotImplementedError(
+ f"Unsupported index type: {type(index)}. Expected slice, int, "
+ f"list, or torch.Tensor."
+ )
+ return new_dof
+
+ def __getitem__(self, index):
+ """
+ Override the __getitem__ method to handle the labels of the
+ :class:`~pina.label_tensor.LabelTensor` instance. It first performs
+ __getitem__ operation on the :class:`torch.Tensor` part of the instance,
+ then updates the labels based on the index.
+
+ :param index: The index used to access the item
+ :type index: int | str | tuple of int | list ot int | torch.Tensor
+ :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
+ `__getitem__` operation on :class:`torch.Tensor` part of the
+ instance, with the updated labels.
+ :rtype: LabelTensor
+
+ :raises KeyError: If an invalid label index is provided.
+ :raises IndexError: If an invalid index is accessed in the tensor.
+ """
+
+ # Handle string index
+ if isinstance(index, str) or (
+ isinstance(index, (tuple, list))
+ and all(isinstance(i, str) for i in index)
+ ):
+ return self.extract(index)
+
+ # Retrieve selected tensor and labels
+ selected_tensor = super().__getitem__(index)
+ if not hasattr(self, "_labels"):
+ return selected_tensor
+
+ original_labels = self._labels
+ updated_labels = copy(original_labels)
+
+ # Ensure the index is iterable
+ if not isinstance(index, tuple):
+ index = [index]
+
+ # Update labels based on the index
+ offset = 0
+ removed = 0
+ for dim, idx in enumerate(index):
+ if dim in original_labels:
+ if isinstance(idx, int):
+ # Compute the working dimension considering the removed
+ # dimensions due to int index on a non labled dimension
+ dim_ = dim - removed
+ selected_tensor = selected_tensor.unsqueeze(dim_)
+ if idx != slice(None):
+ # Update the labels for the selected dimension
+ updated_labels[offset] = {
+ "dof": self._update_single_label(idx, dim),
+ "name": original_labels[dim]["name"],
+ }
+ else:
+ # Adjust label keys if dimension is reduced (case of integer
+ # index on a non-labeled dimension)
+ if isinstance(idx, int):
+ updated_labels = {
+ key - 1 if key > dim else key: value
+ for key, value in updated_labels.items()
+ }
+ removed += 1
+ continue
+ offset += 1
+
+ # Update the selected tensor's labels
+ selected_tensor._labels = updated_labels
+ return selected_tensor
+
+ def sort_labels(self, dim=None):
+ """
+ Sort the labels along the specified dimension and apply. It applies the
+ same sorting to the tensor part of the instance.
+
+ :param int dim: The dimension along which to sort the labels.
+ If ``None``, the last dimension is used.
+ :return: A new tensor with sorted labels along the specified dimension.
+ :rtype: LabelTensor
+ """
+
+ def arg_sort(lst):
+ return sorted(range(len(lst)), key=lambda x: lst[x])
+
+ if dim is None:
+ dim = self.ndim - 1
+ if self.shape[dim] == 1:
+ return self
+ labels = self.stored_labels[dim]["dof"]
+ sorted_index = arg_sort(labels)
+ # Define an indexer to sort the tensor along the specified dimension
+ indexer = [slice(None)] * self.ndim
+ # Assigned the sorted index to the specified dimension
+ indexer[dim] = sorted_index
+ return self[tuple(indexer)]
+
+ def __deepcopy__(self, memo):
+ """
+ Creates a deep copy of the object. For more details, see
+ :meth:`copy.deepcopy`.
+
+ :param memo: LabelTensor object to be copied.
+ :type memo: LabelTensor
+ :return: A deep copy of the original LabelTensor object.
+ :rtype: LabelTensor
+ """
+
+ cls = self.__class__
+ result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels))
+ return result
+
+ def permute(self, *dims):
+ """
+ Permutes the dimensions of the tensor and the associated labels
+ accordingly. For more details, see :meth:`torch.Tensor.permute`.
+
+ :param dims: The dimensions to permute the tensor to.
+ :type dims: tuple[int] | list[int]
+ :return: A new object with permuted dimensions and reordered labels.
+ :rtype: LabelTensor
+ """
+ # Call the base class permute method
+ tensor = super().permute(*dims)
+
+ # Update lables
+ labels = self._labels
+ keys_list = list(*dims)
+ labels = {keys_list.index(k): v for k, v in labels.items()}
+
+ # Assign labels to the new tensor
+ tensor._labels = labels
+ return tensor
+
+ def detach(self):
+ """
+ Detaches the tensor from the computation graph and retains the stored
+ labels. For more details, see :meth:`torch.Tensor.detach`.
+
+ :return: A new tensor detached from the computation graph.
+ :rtype: LabelTensor
+ """
+
+ lt = super().detach()
+
+ # Copy the labels to the new tensor only if present
+ if hasattr(self, "_labels"):
+ lt._labels = self.stored_labels
+ return lt
+
+ @staticmethod
+ def summation(tensors):
+ """
+ Computes the summation of a list of
+ :class:`~pina.label_tensor.LabelTensor` instances.
+
+
+ :param list[LabelTensor] tensors: A list of tensors to sum. All
+ tensors must have the same shape and labels.
+ :return: A new `LabelTensor` containing the element-wise sum of the
+ input tensors.
+ :rtype: LabelTensor
+
+ :raises ValueError: If the input `tensors` list is empty.
+ :raises RuntimeError: If the tensors have different shapes and/or
+ mismatched labels.
+ """
+
+ if not tensors:
+ raise ValueError("The tensors list must not be empty.")
+
+ if len(tensors) == 1:
+ return tensors[0]
+
+ # Initialize result tensor and labels
+ data = torch.zeros_like(tensors[0].tensor).to(tensors[0].device)
+ last_dim_labels = []
+
+ # Accumulate tensors
+ for tensor in tensors:
+ data += tensor.tensor
+ last_dim_labels.append(tensor.labels)
+
+ # Construct last dimension labels
+ last_dim_labels = ["+".join(items) for items in zip(*last_dim_labels)]
+
+ # Update the labels for the resulting tensor
+ labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
+ labels[tensors[0].ndim - 1] = {
+ "dof": last_dim_labels,
+ "name": tensors[0].name,
+ }
+
+ return LabelTensor(data, labels)
+
+ def reshape(self, *shape):
+ """
+ Override the reshape method to update the labels of the tensor.
+ For more details, see :meth:`torch.Tensor.reshape`.
+
+ :param tuple of int shape: The new shape of the tensor.
+ :return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
+ updated shape and labels.
+ :rtype: LabelTensor
+ """
+
+ # As for now the reshape method is used only in the context of the
+ # dataset, the labels are not
+ tensor = super().reshape(*shape)
+ if not hasattr(self, "_labels") or shape != (-1, *self.shape[2:]):
+ return tensor
+ tensor.labels = self.labels
+ return tensor
diff --git a/pina/_src/core/operator.py b/pina/_src/core/operator.py
new file mode 100644
index 000000000..8ed28c3a6
--- /dev/null
+++ b/pina/_src/core/operator.py
@@ -0,0 +1,482 @@
+"""Module for vectorized differential operators implementation.
+
+Differential operators are used to define differential problems and are
+implemented to run efficiently on various accelerators, including CPU, GPU, TPU,
+and MPS.
+
+Each differential operator takes the following inputs:
+- A tensor on which the operator is applied.
+- A tensor with respect to which the operator is computed.
+- The names of the output variables for which the operator is evaluated.
+- The names of the variables with respect to which the operator is computed.
+
+Each differential operator has its fast version, which performs no internal
+checks on input and output tensors. For these methods, the user is always
+required to specify both ``components`` and ``d`` as lists of strings.
+"""
+
+import torch
+from pina._src.core.label_tensor import LabelTensor
+
+
+def _check_values(output_, input_, components, d):
+ """
+ Perform checks on arguments of differential operators.
+
+ :param LabelTensor output_: The output tensor on which the operator is
+ computed.
+ :param LabelTensor input_: The input tensor with respect to which the
+ operator is computed.
+ :param components: The names of the output variables for which to compute
+ the operator. It must be a subset of the output labels.
+ If ``None``, all output variables are considered. Default is ``None``.
+ :type components: str | list[str]
+ :param d: The names of the input variables with respect to which the
+ operator is computed. It must be a subset of the input labels.
+ If ``None``, all input variables are considered. Default is ``None``.
+ :type d: str | list[str]
+ :raises TypeError: If the input tensor is not a LabelTensor.
+ :raises TypeError: If the output tensor is not a LabelTensor.
+ :raises RuntimeError: If derivative labels are missing from the ``input_``.
+ :raises RuntimeError: If component labels are missing from the ``output_``.
+ :return: The components and d lists.
+ :rtype: tuple[list[str], list[str]]
+ """
+ # Check if the input is a LabelTensor
+ if not isinstance(input_, LabelTensor):
+ raise TypeError("Input must be a LabelTensor.")
+
+ # Check if the output is a LabelTensor
+ if not isinstance(output_, LabelTensor):
+ raise TypeError("Output must be a LabelTensor.")
+
+ # If no labels are provided, use all labels
+ d = d or input_.labels
+ components = components or output_.labels
+
+ # Convert to list if not already
+ d = d if isinstance(d, list) else [d]
+ components = components if isinstance(components, list) else [components]
+
+ # Check if all labels are present in the input tensor
+ if not all(di in input_.labels for di in d):
+ raise RuntimeError("Derivative labels missing from input tensor.")
+
+ # Check if all labels are present in the output tensor
+ if not all(c in output_.labels for c in components):
+ raise RuntimeError("Component label missing from output tensor.")
+
+ return components, d
+
+
+def _scalar_grad(output_, input_, d):
+ """
+ Compute the gradient of a scalar-valued ``output_``.
+
+ :param LabelTensor output_: The output tensor on which the gradient is
+ computed. It must be a column tensor.
+ :param LabelTensor input_: The input tensor with respect to which the
+ gradient is computed.
+ :param list[str] d: The names of the input variables with respect to
+ which the gradient is computed. It must be a subset of the input
+ labels. If ``None``, all input variables are considered.
+ :return: The computed gradient tensor.
+ :rtype: LabelTensor
+ """
+ grad_out = torch.autograd.grad(
+ outputs=output_,
+ inputs=input_,
+ grad_outputs=torch.ones_like(output_),
+ create_graph=True,
+ retain_graph=True,
+ allow_unused=True,
+ )[0]
+
+ return grad_out[..., [input_.labels.index(i) for i in d]]
+
+
+def _scalar_laplacian(output_, input_, d):
+ """
+ Compute the laplacian of a scalar-valued ``output_``.
+
+ :param LabelTensor output_: The output tensor on which the laplacian is
+ computed. It must be a column tensor.
+ :param LabelTensor input_: The input tensor with respect to which the
+ laplacian is computed.
+ :param list[str] d: The names of the input variables with respect to
+ which the laplacian is computed. It must be a subset of the input
+ labels. If ``None``, all input variables are considered.
+ :return: The computed laplacian tensor.
+ :rtype: LabelTensor
+ """
+ first_grad = fast_grad(
+ output_=output_, input_=input_, components=output_.labels, d=d
+ )
+ second_grad = fast_grad(
+ output_=first_grad, input_=input_, components=first_grad.labels, d=d
+ )
+ labels_to_extract = [f"d{c}d{d_}" for c, d_ in zip(first_grad.labels, d)]
+ return torch.sum(
+ second_grad.extract(labels_to_extract), dim=-1, keepdim=True
+ )
+
+
+def fast_grad(output_, input_, components, d):
+ """
+ Compute the gradient of the ``output_`` with respect to the ``input``.
+
+ Unlike ``grad``, this function performs no internal checks on input and
+ output tensors. The user is required to specify both ``components`` and
+ ``d`` as lists of strings. It is designed to enhance computation speed.
+
+ This operator supports both vector-valued and scalar-valued functions with
+ one or multiple input coordinates.
+
+ :param LabelTensor output_: The output tensor on which the gradient is
+ computed.
+ :param LabelTensor input_: The input tensor with respect to which the
+ gradient is computed.
+ :param list[str] components: The names of the output variables for which to
+ compute the gradient. It must be a subset of the output labels.
+ :param list[str] d: The names of the input variables with respect to which
+ the gradient is computed. It must be a subset of the input labels.
+ :return: The computed gradient tensor.
+ :rtype: LabelTensor
+ """
+ # Scalar gradient
+ if output_.shape[-1] == 1:
+ return LabelTensor(
+ _scalar_grad(output_=output_, input_=input_, d=d),
+ labels=[f"d{output_.labels[0]}d{i}" for i in d],
+ )
+
+ # Vector gradient
+ grads = torch.cat(
+ [
+ _scalar_grad(output_=output_.extract(c), input_=input_, d=d)
+ for c in components
+ ],
+ dim=-1,
+ )
+
+ return LabelTensor(
+ grads, labels=[f"d{c}d{i}" for c in components for i in d]
+ )
+
+
+def fast_div(output_, input_, components, d):
+ """
+ Compute the divergence of the ``output_`` with respect to ``input``.
+
+ Unlike ``div``, this function performs no internal checks on input and
+ output tensors. The user is required to specify both ``components`` and
+ ``d`` as lists of strings. It is designed to enhance computation speed.
+
+ This operator supports vector-valued functions with multiple input
+ coordinates.
+
+ :param LabelTensor output_: The output tensor on which the divergence is
+ computed.
+ :param LabelTensor input_: The input tensor with respect to which the
+ divergence is computed.
+ :param list[str] components: The names of the output variables for which to
+ compute the divergence. It must be a subset of the output labels.
+ :param list[str] d: The names of the input variables with respect to which
+ the divergence is computed. It must be a subset of the input labels.
+ :rtype: LabelTensor
+ """
+ grad_out = fast_grad(
+ output_=output_, input_=input_, components=components, d=d
+ )
+ tensors_to_sum = [
+ grad_out.extract(f"d{c}d{d_}") for c, d_ in zip(components, d)
+ ]
+
+ return LabelTensor.summation(tensors_to_sum)
+
+
+def fast_laplacian(output_, input_, components, d, method="std"):
+ """
+ Compute the laplacian of the ``output_`` with respect to ``input``.
+
+ Unlike ``laplacian``, this function performs no internal checks on input and
+ output tensors. The user is required to specify both ``components`` and
+ ``d`` as lists of strings. It is designed to enhance computation speed.
+
+ This operator supports both vector-valued and scalar-valued functions with
+ one or multiple input coordinates.
+
+ :param LabelTensor output_: The output tensor on which the laplacian is
+ computed.
+ :param LabelTensor input_: The input tensor with respect to which the
+ laplacian is computed.
+ :param list[str] components: The names of the output variables for which to
+ compute the laplacian. It must be a subset of the output labels.
+ :param list[str] d: The names of the input variables with respect to which
+ the laplacian is computed. It must be a subset of the input labels.
+ :param str method: The method used to compute the Laplacian. Available
+ methods are ``std`` and ``divgrad``. The ``std`` method computes the
+ trace of the Hessian matrix, while the ``divgrad`` method computes the
+ divergence of the gradient. Default is ``std``.
+ :return: The computed laplacian tensor.
+ :rtype: LabelTensor
+ :raises ValueError: If the passed method is neither ``std`` nor ``divgrad``.
+ """
+ # Scalar laplacian
+ if output_.shape[-1] == 1:
+ return LabelTensor(
+ _scalar_laplacian(output_=output_, input_=input_, d=d),
+ labels=[f"dd{c}" for c in components],
+ )
+
+ # Initialize the result tensor and its labels
+ labels = [f"dd{c}" for c in components]
+ result = torch.empty(
+ input_.shape[0], len(components), device=output_.device
+ )
+
+ # Vector laplacian
+ if method == "std":
+ result = torch.cat(
+ [
+ _scalar_laplacian(
+ output_=output_.extract(c), input_=input_, d=d
+ )
+ for c in components
+ ],
+ dim=-1,
+ )
+
+ elif method == "divgrad":
+ grads = fast_grad(
+ output_=output_, input_=input_, components=components, d=d
+ )
+ result = torch.cat(
+ [
+ fast_div(
+ output_=grads,
+ input_=input_,
+ components=[f"d{c}d{i}" for i in d],
+ d=d,
+ )
+ for c in components
+ ],
+ dim=-1,
+ )
+
+ else:
+ raise ValueError(
+ "Invalid method. Available methods are ``std`` and ``divgrad``."
+ )
+
+ return LabelTensor(result, labels=labels)
+
+
+def fast_advection(output_, input_, velocity_field, components, d):
+ """
+ Perform the advection operation on the ``output_`` with respect to the
+ ``input``. This operator supports vector-valued functions with multiple
+ input coordinates.
+
+ Unlike ``advection``, this function performs no internal checks on input and
+ output tensors. The user is required to specify both ``components`` and
+ ``d`` as lists of strings. It is designed to enhance computation speed.
+
+ :param LabelTensor output_: The output tensor on which the advection is
+ computed. It includes both the velocity and the quantity to be advected.
+ :param LabelTensor input_: the input tensor with respect to which advection
+ is computed.
+ :param list[str] velocity_field: The name of the output variables used as
+ velocity field. It must be chosen among the output labels.
+ :param list[str] components: The names of the output variables for which to
+ compute the advection. It must be a subset of the output labels.
+ :param list[str] d: The names of the input variables with respect to which
+ the advection is computed. It must be a subset of the input labels.
+ :return: The computed advection tensor.
+ :rtype: LabelTensor
+ """
+ # Add a dimension to the velocity field for following operations
+ velocity = output_.extract(velocity_field).unsqueeze(-1)
+
+ # Compute the gradient
+ grads = fast_grad(
+ output_=output_, input_=input_, components=components, d=d
+ )
+
+ # Reshape into [..., len(filter_components), len(d)]
+ tmp = grads.reshape(*output_.shape[:-1], len(components), len(d))
+
+ # Transpose to [..., len(d), len(filter_components)]
+ tmp = tmp.transpose(-1, -2)
+
+ adv = (tmp * velocity).sum(dim=tmp.tensor.ndim - 2)
+ return LabelTensor(adv, labels=[f"adv_{c}" for c in components])
+
+
+def grad(output_, input_, components=None, d=None):
+ """
+ Compute the gradient of the ``output_`` with respect to the ``input``.
+
+ This operator supports both vector-valued and scalar-valued functions with
+ one or multiple input coordinates.
+
+ :param LabelTensor output_: The output tensor on which the gradient is
+ computed.
+ :param LabelTensor input_: The input tensor with respect to which the
+ gradient is computed.
+ :param components: The names of the output variables for which to compute
+ the gradient. It must be a subset of the output labels.
+ If ``None``, all output variables are considered. Default is ``None``.
+ :type components: str | list[str]
+ :param d: The names of the input variables with respect to which the
+ gradient is computed. It must be a subset of the input labels.
+ If ``None``, all input variables are considered. Default is ``None``.
+ :type d: str | list[str]
+ :raises TypeError: If the input tensor is not a LabelTensor.
+ :raises TypeError: If the output tensor is not a LabelTensor.
+ :raises RuntimeError: If derivative labels are missing from the ``input_``.
+ :raises RuntimeError: If component labels are missing from the ``output_``.
+ :return: The computed gradient tensor.
+ :rtype: LabelTensor
+ """
+ components, d = _check_values(
+ output_=output_, input_=input_, components=components, d=d
+ )
+ return fast_grad(output_=output_, input_=input_, components=components, d=d)
+
+
+def div(output_, input_, components=None, d=None):
+ """
+ Compute the divergence of the ``output_`` with respect to ``input``.
+
+ This operator supports vector-valued functions with multiple input
+ coordinates.
+
+ :param LabelTensor output_: The output tensor on which the divergence is
+ computed.
+ :param LabelTensor input_: The input tensor with respect to which the
+ divergence is computed.
+ :param components: The names of the output variables for which to compute
+ the divergence. It must be a subset of the output labels.
+ If ``None``, all output variables are considered. Default is ``None``.
+ :type components: str | list[str]
+ :param d: The names of the input variables with respect to which the
+ divergence is computed. It must be a subset of the input labels.
+ If ``None``, all input variables are considered. Default is ``None``.
+ :type components: str | list[str]
+ :raises TypeError: If the input tensor is not a LabelTensor.
+ :raises TypeError: If the output tensor is not a LabelTensor.
+ :raises ValueError: If the length of ``components`` and ``d`` do not match.
+ :return: The computed divergence tensor.
+ :rtype: LabelTensor
+ """
+ components, d = _check_values(
+ output_=output_, input_=input_, components=components, d=d
+ )
+
+ # Components and d must be of the same length
+ if len(components) != len(d):
+ raise ValueError(
+ "Divergence requires components and d to be of the same length."
+ )
+
+ return fast_div(output_=output_, input_=input_, components=components, d=d)
+
+
+def laplacian(output_, input_, components=None, d=None, method="std"):
+ """
+ Compute the laplacian of the ``output_`` with respect to ``input``.
+
+ This operator supports both vector-valued and scalar-valued functions with
+ one or multiple input coordinates.
+
+ :param LabelTensor output_: The output tensor on which the laplacian is
+ computed.
+ :param LabelTensor input_: The input tensor with respect to which the
+ laplacian is computed.
+ :param components: The names of the output variables for which to
+ compute the laplacian. It must be a subset of the output labels.
+ If ``None``, all output variables are considered. Default is ``None``.
+ :type components: str | list[str]
+ :param d: The names of the input variables with respect to which
+ the laplacian is computed. It must be a subset of the input labels.
+ If ``None``, all input variables are considered. Default is ``None``.
+ :type d: str | list[str]
+ :param str method: The method used to compute the Laplacian. Available
+ methods are ``std`` and ``divgrad``. The ``std`` method computes the
+ trace of the Hessian matrix, while the ``divgrad`` method computes the
+ divergence of the gradient. Default is ``std``.
+ :raises TypeError: If the input tensor is not a LabelTensor.
+ :raises TypeError: If the output tensor is not a LabelTensor.
+ :raises ValueError: If the passed method is neither ``std`` nor ``divgrad``.
+ :return: The computed laplacian tensor.
+ :rtype: LabelTensor
+ """
+ components, d = _check_values(
+ output_=output_, input_=input_, components=components, d=d
+ )
+
+ return fast_laplacian(
+ output_=output_,
+ input_=input_,
+ components=components,
+ d=d,
+ method=method,
+ )
+
+
+def advection(output_, input_, velocity_field, components=None, d=None):
+ """
+ Perform the advection operation on the ``output_`` with respect to the
+ ``input``. This operator supports vector-valued functions with multiple
+ input coordinates.
+
+ :param LabelTensor output_: The output tensor on which the advection is
+ computed. It includes both the velocity and the quantity to be advected.
+ :param LabelTensor input_: the input tensor with respect to which advection
+ is computed.
+ :param velocity_field: The name of the output variables used as velocity
+ field. It must be chosen among the output labels.
+ :type velocity_field: str | list[str]
+ :param components: The names of the output variables for which to compute
+ the advection. It must be a subset of the output labels.
+ If ``None``, all output variables are considered. Default is ``None``.
+ :type components: str | list[str]
+ :param d: The names of the input variables with respect to which the
+ advection is computed. It must be a subset of the input labels.
+ If ``None``, all input variables are considered. Default is ``None``.
+ :type d: str | list[str]
+ :raises TypeError: If the input tensor is not a LabelTensor.
+ :raises TypeError: If the output tensor is not a LabelTensor.
+ :raises RuntimeError: If the velocity field is not a subset of the output
+ labels.
+ :raises RuntimeError: If the dimensionality of the velocity field does not
+ match that of the input tensor.
+ :return: The computed advection tensor.
+ :rtype: LabelTensor
+ """
+ components, d = _check_values(
+ output_=output_, input_=input_, components=components, d=d
+ )
+
+ # Map velocity_field to a list if it is a string
+ if isinstance(velocity_field, str):
+ velocity_field = [velocity_field]
+
+ # Check if all the velocity_field labels are present in the output labels
+ if not all(vi in output_.labels for vi in velocity_field):
+ raise RuntimeError("Velocity labels missing from output tensor.")
+
+ # Check if the velocity has the same dimensionality as the input tensor
+ if len(velocity_field) != len(d):
+ raise RuntimeError(
+ "Velocity dimensionality does not match input dimensionality."
+ )
+
+ return fast_advection(
+ output_=output_,
+ input_=input_,
+ velocity_field=velocity_field,
+ components=components,
+ d=d,
+ )
diff --git a/pina/_src/core/trainer.py b/pina/_src/core/trainer.py
new file mode 100644
index 000000000..0b89ab168
--- /dev/null
+++ b/pina/_src/core/trainer.py
@@ -0,0 +1,289 @@
+"""Trainer utilities built on top of the PyTorch Lightning Trainer class."""
+
+import warnings
+import torch
+import lightning
+from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
+from pina._src.solver.base_solver import BaseSolver
+from pina._src.data.data_module import DataModule
+from pina._src.core.utils import (
+ check_consistency,
+ custom_warning_format,
+ check_positive_integer,
+)
+
+# Set custom warning format and filter warnings
+warnings.formatwarning = custom_warning_format
+warnings.filterwarnings("always", category=UserWarning)
+
+
+class Trainer(lightning.pytorch.Trainer):
+ """
+ PINA-specific extension of :class:`lightning.pytorch.Trainer`.
+
+ The trainer configures solver execution, dataset splitting, batching,
+ logging, device placement for unknown parameters, and gradient tracking
+ requirements for physics-informed solvers.
+ """
+
+ # Available batching modes
+ _AVAIL_BATCHING_MODES = {
+ "common_batch_size",
+ "proportional",
+ "separate_conditions",
+ }
+
+ def __init__(
+ self,
+ solver,
+ batch_size=None,
+ train_size=1.0,
+ test_size=0.0,
+ val_size=0.0,
+ batching_mode="common_batch_size",
+ automatic_batching=False,
+ num_workers=0,
+ pin_memory=False,
+ shuffle=True,
+ **kwargs,
+ ):
+ """
+ Initialization of the :class:`Trainer` class.
+
+ :param BaseSolver solver: The solver used to train, validate, and test
+ the associated problem.
+ :param int batch_size: The number of samples per batch. If ``None``, the
+ entire dataset is processed as a single batch. Default is ``None``.
+ :param float train_size: The fraction of samples assigned to the
+ training split. Must belong to the interval ``[0, 1]``.
+ Default is ``1.0``.
+ :param float val_size: The fraction of samples assigned to the
+ validation split. Must belong to the interval ``[0, 1]``.
+ Default is ``0.0``.
+ :param float test_size: The fraction of samples assigned to the test
+ split. Must belong to the interval ``[0, 1]``. Default is ``0.0``.
+ :param str batching_mode: The strategy used to aggregate batches across
+ dataloaders. Available options are ``"common_batch_size"`` for
+ uniform batch sizes across conditions, ``"proportional"`` for batch
+ sizes proportional to dataset sizes, and ``"separate_conditions"``
+ for iterating through each condition separately.
+ Default is ``"common_batch_size"``.
+ :param bool automatic_batching: Whether PyTorch automatic batching
+ should be enabled. If ``True``, dataset elements are retrieved
+ individually and collated into batches by the dataloader.
+ If ``False``, entire subsets are retrieved directly from the
+ condition object. Default is ``False``.
+ :param int num_workers: The number of worker processes used by
+ dataloaders. Default is ``0`` for sequential loading.
+ :param bool pin_memory: Whether pinned memory should be enabled during
+ data loading. Default is ``False``.
+ :param bool shuffle: Whether condition samples should be shuffled before
+ splitting. Default is ``True``.
+ :param dict kwargs: Additional keyword arguments forwarded to the
+ Lightning trainer.
+ :raises ValueError: If ``solver`` is not a PINA solver.
+ :raises ValueError: If ``train_size``, ``val_size``, or ``test_size`` is
+ not a float in the interval ``[0, 1]``.
+ :raises ValueError: If the sum of ``train_size``, ``val_size``, and
+ ``test_size`` is not equal to 1.
+ :raises ValueError: If ``automatic_batching``, ``pin_memory``, or
+ ``shuffle`` is not a boolean.
+ :raises AssertionError: If ``num_workers`` is a negative integer.
+ :raises ValueError: If ``batch_size``, when provided, is not a positive
+ integer.
+ :raises ValueError: If ``batching_mode`` is not one of the available
+ options.
+ :raises UserWarning: If the provided ``batching_mode`` is incompatible
+ with the ``batch_size``.
+ :raises RuntimeError: If any domain in the problem has not been
+ discretised.
+ """
+ # Backward compatibility: compile has been removed
+ if "compile" in kwargs:
+ warnings.warn(
+ "`compile` is deprecated and no longer used. Compilation is "
+ "now disabled and the argument will be ignored.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ kwargs.pop("compile")
+
+ # Check consistency
+ check_consistency(solver, BaseSolver)
+ check_consistency(train_size, float)
+ check_consistency(test_size, float)
+ check_consistency(val_size, float)
+ check_consistency(automatic_batching, bool)
+ check_consistency(pin_memory, bool)
+ check_consistency(shuffle, bool)
+ check_positive_integer(num_workers, strict=False)
+ if batch_size is not None:
+ check_positive_integer(batch_size, strict=True)
+
+ # Check that train_size, test_size and val_size sum to 1
+ total = train_size + val_size + test_size
+ if not torch.isclose(torch.tensor(total), torch.tensor(1.0)):
+ raise ValueError(
+ "`train_size`, `val_size`, and `test_size` must sum to 1."
+ )
+
+ # Check consistency
+ if batching_mode not in self._AVAIL_BATCHING_MODES:
+ raise ValueError(
+ f"Invalid batching mode '{batching_mode}'. "
+ f"Expected one of: {sorted(self._AVAIL_BATCHING_MODES)}."
+ )
+
+ # Set inference mode to false when usiing physics-informed mixin
+ if isinstance(solver, PhysicsInformedMixin):
+ kwargs["inference_mode"] = False
+
+ # Set log_every_n_steps to 0 if batch_size is None, otherwise default
+ kwargs["log_every_n_steps"] = (
+ 0 if batch_size is None else kwargs.get("log_every_n_steps", 50)
+ )
+
+ # Set default value for enable_progress_bar to True if not provided
+ kwargs.setdefault("enable_progress_bar", True)
+
+ # Initialize the parent class with the provided keyword arguments
+ super().__init__(**kwargs)
+
+ # Raise warning if batch size and batching mode are incompatible
+ if batch_size is None and batching_mode != "common_batch_size":
+ warnings.warn(
+ f"Batching mode '{batching_mode}' is ignored when the batch "
+ "size is None. Setting batching_mode to 'common_batch_size'.",
+ UserWarning,
+ )
+
+ # Set batching mode to common_batch_size if incompatible
+ batching_mode = "common_batch_size"
+
+ # Raise warning if batch size and batching mode are incompatible
+ if (
+ batch_size is not None
+ and batching_mode == "proportional"
+ and batch_size <= len(solver.problem.conditions)
+ ):
+ warnings.warn(
+ "Batching mode 'proportional' requires the batch size to be "
+ "larger than the number of conditions. Setting batching_mode "
+ "to 'common_batch_size'.",
+ UserWarning,
+ )
+
+ # Set batching mode to common_batch_size if incompatible
+ batching_mode = "common_batch_size"
+
+ # Initialize the class attributes
+ self.solver = solver
+ self.batch_size = batch_size
+
+ # Move the unknown parameters to the correct device
+ self._move_to_device()
+
+ # Check that all domains are discretised, otherwise raise an error
+ if not self.solver.problem.are_all_domains_discretised:
+
+ # Get the list of sampled domains from the problem
+ sampled_domains = self.solver.problem.discretised_domains
+
+ # Create a status message for each domain
+ status = "\n".join(
+ f" - Domain '{name}': "
+ f"{'sampled' if name in sampled_domains else 'not sampled'}"
+ for name in self.solver.problem.domains
+ )
+
+ # Raise an error with the status of each domain
+ raise RuntimeError(
+ "Cannot create the Trainer because some domains have not been "
+ f"sampled. Domain status:\n{status}"
+ )
+
+ # Create the data module
+ self.data_module = DataModule(
+ problem=self.solver.problem,
+ train_size=train_size,
+ test_size=test_size,
+ val_size=val_size,
+ batch_size=self.batch_size,
+ batching_mode=batching_mode,
+ automatic_batching=automatic_batching,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ shuffle=shuffle,
+ )
+
+ # Set logging kwargs
+ self.logging_kwargs = {
+ "sync_dist": bool(
+ len(self._accelerator_connector._parallel_devices) > 1
+ ),
+ "on_step": bool(kwargs["log_every_n_steps"] > 0),
+ "prog_bar": bool(kwargs["enable_progress_bar"]),
+ "on_epoch": True,
+ }
+
+ def _move_to_device(self):
+ """
+ Move problem unknown parameters to the trainer device.
+
+ If the associated problem defines ``unknown_parameters``, each parameter
+ is moved to the first device configured by the Lightning accelerator
+ connector.
+ """
+ # Get the device from the accelerator connector
+ device = self._accelerator_connector._parallel_devices[0]
+
+ # Get the problem instance from the solver
+ problem = self.solver.problem
+
+ # Move the unknown parameters to the correct device if they exist
+ if hasattr(problem, "unknown_parameters"):
+ for key in problem.unknown_parameters:
+ problem.unknown_parameters[key] = torch.nn.Parameter(
+ problem.unknown_parameters[key].data.to(device)
+ )
+
+ def train(self, **kwargs):
+ """
+ Fit the solver using the trainer data module.
+
+ :param dict kwargs: Additional keyword arguments forwarded to the
+ Lightning trainer ``fit`` method.
+ :return: Result returned by Lightning's ``fit`` method.
+ :rtype: Any
+ """
+ return super().fit(self.solver, datamodule=self.data_module, **kwargs)
+
+ def test(self, **kwargs):
+ """
+ Test the solver using the trainer data module.
+
+ :param dict kwargs: Additional keyword arguments forwarded to the
+ Lightning trainer ``test`` method.
+ :return: Result returned by Lightning's ``test`` method.
+ :rtype: Any
+ """
+ return super().test(self.solver, datamodule=self.data_module, **kwargs)
+
+ @property
+ def solver(self):
+ """
+ Return the solver attached to the trainer.
+
+ :return: The solver used by the trainer.
+ :rtype: BaseSolver
+ """
+ return self._solver
+
+ @solver.setter
+ def solver(self, solver):
+ """
+ Set the solver attached to the trainer.
+
+ :param BaseSolver solver: The solver instance to attach.
+ """
+ self._solver = solver
diff --git a/pina/_src/core/type_checker.py b/pina/_src/core/type_checker.py
new file mode 100644
index 000000000..e8c908ac9
--- /dev/null
+++ b/pina/_src/core/type_checker.py
@@ -0,0 +1,93 @@
+"""Module for enforcing type hints in Python functions."""
+
+import inspect
+import typing
+import logging
+
+
+def enforce_types(func):
+ """
+ Function decorator to enforce type hints at runtime.
+
+ This decorator checks the types of the arguments and of the return value of
+ the decorated function against the type hints specified in the function
+ signature. If the types do not match, a TypeError is raised.
+ Type checking is only performed when the logging level is set to `DEBUG`.
+
+ :param Callable func: The function to be decorated.
+ :return: The decorated function with enforced type hints.
+ :rtype: Callable
+
+ :Example:
+
+ >>> @enforce_types
+ def dummy_function(a: int, b: float) -> float:
+ ... return a+b
+
+ # This always works.
+ dummy_function(1, 2.0)
+
+ # This raises a TypeError for the second argument, if logging is set to
+ # `DEBUG`.
+ dummy_function(1, "Hello, world!")
+
+
+ >>> @enforce_types
+ def dummy_function2(a: int, right: bool) -> float:
+ ... if right:
+ ... return float(a)
+ ... else:
+ ... return "Hello, world!"
+
+ # This always works.
+ dummy_function2(1, right=True)
+
+ # This raises a TypeError for the return value if logging is set to
+ # `DEBUG`.
+ dummy_function2(1, right=False)
+ """
+
+ def wrapper(*args, **kwargs):
+ """
+ Wrapper function to enforce type hints.
+
+ :param tuple args: Positional arguments passed to the function.
+ :param dict kwargs: Keyword arguments passed to the function.
+ :raises TypeError: If the argument or return type does not match the
+ specified type hints.
+ :return: The result of the decorated function.
+ :rtype: Any
+ """
+ level = logging.getLevelName(logging.getLogger().getEffectiveLevel())
+
+ # Enforce type hints only in debug mode
+ if level != "DEBUG":
+ return func(*args, **kwargs)
+
+ # Get the type hints for the function arguments
+ hints = typing.get_type_hints(func)
+ sig = inspect.signature(func)
+ bound = sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+
+ for arg_name, arg_value in bound.arguments.items():
+ expected_type = hints.get(arg_name)
+ if expected_type and not isinstance(arg_value, expected_type):
+ raise TypeError(
+ f"Argument '{arg_name}' must be {expected_type.__name__}, "
+ f"but got {type(arg_value).__name__}!"
+ )
+
+ # Get the type hints for the return values
+ return_type = hints.get("return")
+ result = func(*args, **kwargs)
+
+ if return_type and not isinstance(result, return_type):
+ raise TypeError(
+ f"Return value must be {return_type.__name__}, "
+ f"but got {type(result).__name__}!"
+ )
+
+ return result
+
+ return wrapper
diff --git a/pina/_src/core/utils.py b/pina/_src/core/utils.py
new file mode 100644
index 000000000..d0226ea83
--- /dev/null
+++ b/pina/_src/core/utils.py
@@ -0,0 +1,270 @@
+"""Module for utility functions."""
+
+import types
+from functools import reduce
+import torch
+
+from pina._src.core.label_tensor import LabelTensor
+
+
+# Codacy error unused parameters
+def custom_warning_format(
+ message, category, filename, lineno, file=None, line=None
+):
+ """
+ Custom warning formatting function.
+
+ :param str message: The warning message.
+ :param Warning category: The warning category.
+ :param str filename: The filename where the warning is raised.
+ :param int lineno: The line number where the warning is raised.
+ :param str file: The file object where the warning is raised.
+ Default is None.
+ :param int line: The line where the warning is raised.
+ :return: The formatted warning message.
+ :rtype: str
+ """
+ return f"{filename}: {category.__name__}: {message}\n"
+
+
+def check_consistency(object_, object_instance, subclass=False):
+ """
+ Check if an object maintains inheritance consistency.
+
+ This function checks whether a given object is an instance of a specified
+ class or, if ``subclass=True``, whether it is a subclass of the specified
+ class.
+
+ :param object: The object to check.
+ :type object: Iterable | Object
+ :param Object object_instance: The expected parent class.
+ :param bool subclass: If True, checks whether ``object_`` is a subclass
+ of ``object_instance`` instead of an instance. Default is ``False``.
+ :raises ValueError: If ``object_`` does not inherit from ``object_instance``
+ as expected.
+ """
+ if not isinstance(object_, (list, set, tuple)):
+ object_ = [object_]
+
+ for obj in object_:
+ is_class = isinstance(obj, type)
+ expected_type_name = (
+ object_instance.__name__
+ if isinstance(object_instance, type)
+ else str(object_instance)
+ )
+
+ if subclass:
+ if not is_class:
+ raise ValueError(
+ f"You passed {repr(obj)} "
+ f"(an instance of {type(obj).__name__}), "
+ f"but a {expected_type_name} class was expected. "
+ f"Please pass a {expected_type_name} class or a "
+ "derived one."
+ )
+ if not issubclass(obj, object_instance):
+ raise ValueError(
+ f"You passed {obj.__name__} class, but a "
+ f"{expected_type_name} class was expected. "
+ f"Please pass a {expected_type_name} class or a "
+ "derived one."
+ )
+ else:
+ if is_class:
+ raise ValueError(
+ f"You passed {obj.__name__} class, but a "
+ f"{expected_type_name} instance was expected. "
+ f"Please pass a {expected_type_name} instance."
+ )
+ if not isinstance(obj, object_instance):
+ raise ValueError(
+ f"You passed {repr(obj)} "
+ f"(an instance of {type(obj).__name__}), "
+ f"but a {expected_type_name} instance was expected. "
+ f"Please pass a {expected_type_name} instance."
+ )
+
+
+def labelize_forward(forward, input_variables, output_variables):
+ """
+ Decorator to enable or disable the use of
+ :class:`~pina.label_tensor.LabelTensor` during the forward pass.
+
+ :param Callable forward: The forward function of a :class:`torch.nn.Module`.
+ :param list[str] input_variables: The names of the input variables of a
+ :class:`~pina.problem.base_problem.BaseProblem`.
+ :param list[str] output_variables: The names of the output variables of a
+ :class:`~pina.problem.base_problem.BaseProblem`.
+ :return: The decorated forward function.
+ :rtype: Callable
+ """
+
+ def wrapper(x, *args, **kwargs):
+ """
+ Decorated forward function.
+
+ :param LabelTensor x: The labelized input of the forward pass of an
+ instance of :class:`torch.nn.Module`.
+ :param Iterable args: Additional positional arguments passed to
+ ``forward`` method.
+ :param dict kwargs: Additional keyword arguments passed to
+ ``forward`` method.
+ :return: The labelized output of the forward pass of an instance of
+ :class:`torch.nn.Module`.
+ :rtype: LabelTensor
+ """
+ x = x.extract(input_variables)
+ output = forward(x, *args, **kwargs)
+ # keep it like this, directly using LabelTensor(...) raises errors
+ # when compiling the code
+ output = output.as_subclass(LabelTensor)
+ output.labels = output_variables
+ return output
+
+ return wrapper
+
+
+def merge_tensors(tensors):
+ """
+ Merge a list of :class:`~pina.label_tensor.LabelTensor` instances into a
+ single :class:`~pina.label_tensor.LabelTensor` tensor, by applying
+ iteratively the cartesian product.
+
+ :param list[LabelTensor] tensors: The list of tensors to merge.
+ :raises ValueError: If the list of tensors is empty.
+ :return: The merged tensor.
+ :rtype: LabelTensor
+ """
+ if tensors:
+ return reduce(merge_two_tensors, tensors[1:], tensors[0])
+ raise ValueError("Expected at least one tensor")
+
+
+def merge_two_tensors(tensor1, tensor2):
+ """
+ Merge two :class:`~pina.label_tensor.LabelTensor` instances into a single
+ :class:`~pina.label_tensor.LabelTensor` tensor, by applying the cartesian
+ product.
+
+ :param LabelTensor tensor1: The first tensor to merge.
+ :param LabelTensor tensor2: The second tensor to merge.
+ :return: The merged tensor.
+ :rtype: LabelTensor
+ """
+ n1 = tensor1.shape[0]
+ n2 = tensor2.shape[0]
+
+ tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
+ tensor2 = LabelTensor(
+ tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
+ )
+ return tensor1.append(tensor2)
+
+
+def torch_lhs(n, dim):
+ """
+ The Latin Hypercube Sampling torch routine, sampling in :math:`[0, 1)`$.
+
+ :param int n: The number of points to sample.
+ :param int dim: The number of dimensions of the sampling space.
+ :raises TypeError: If `n` or `dim` are not integers.
+ :raises ValueError: If `dim` is less than 1.
+ :return: The sampled points.
+ :rtype: torch.tensor
+ """
+
+ if not isinstance(n, int):
+ raise TypeError("number of point n must be int")
+
+ if not isinstance(dim, int):
+ raise TypeError("dim must be int")
+
+ if dim < 1:
+ raise ValueError("dim must be greater than one")
+
+ samples = torch.rand(size=(n, dim))
+
+ perms = torch.tile(torch.arange(1, n + 1), (dim, 1))
+
+ for row in range(dim):
+ idx_perm = torch.randperm(perms.shape[-1])
+ perms[row, :] = perms[row, idx_perm]
+
+ perms = perms.T
+
+ samples = (perms - samples) / n
+
+ return samples
+
+
+def is_function(f):
+ """
+ Check if the given object is a function or a lambda.
+
+ :param Object f: The object to be checked.
+ :return: ``True`` if ``f`` is a function, ``False`` otherwise.
+ :rtype: bool
+ """
+ return callable(f)
+
+
+def chebyshev_roots(n):
+ """
+ Compute the roots of the Chebyshev polynomial of degree ``n``.
+
+ :param int n: The number of roots to return.
+ :return: The roots of the Chebyshev polynomials.
+ :rtype: torch.Tensor
+ """
+ pi = torch.acos(torch.zeros(1)).item() * 2
+ k = torch.arange(n)
+ nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0]
+ return nodes
+
+
+def check_positive_integer(value, strict=True):
+ """
+ Check if the value is a positive integer.
+
+ :param int value: The value to check.
+ :param bool strict: If True, the value must be strictly positive.
+ Default is True.
+ :raises AssertionError: If the value is not a positive integer.
+ """
+ if strict:
+ assert (
+ isinstance(value, int) and value > 0
+ ), f"Expected a strictly positive integer, got {value}."
+ else:
+ assert (
+ isinstance(value, int) and value >= 0
+ ), f"Expected a non-negative integer, got {value}."
+
+
+def in_range(value, range_vals, strict=True):
+ """
+ Check if a value is within a specified range.
+
+ :param int value: The integer value to check.
+ :param list[int] range_vals: A list of two integers representing the range
+ limits. The first element specifies the lower bound, and the second
+ specifies the upper bound.
+ :param bool strict: If True, the value must be strictly positive.
+ Default is True.
+ :return: True if the value satisfies the range condition, False otherwise.
+ :rtype: bool
+ """
+ # Validate inputs
+ check_consistency(value, (float, int))
+ check_consistency(range_vals, (float, int))
+ assert (
+ isinstance(range_vals, list) and len(range_vals) == 2
+ ), "range_vals must be a list of two integers [lower, upper]"
+ lower, upper = range_vals
+
+ # Check the range
+ if strict:
+ return lower < value < upper
+
+ return lower <= value <= upper
diff --git a/pina/_src/data/__init__.py b/pina/_src/data/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/data/aggregator.py b/pina/_src/data/aggregator.py
new file mode 100644
index 000000000..d6e149a3f
--- /dev/null
+++ b/pina/_src/data/aggregator.py
@@ -0,0 +1,87 @@
+"""Utility class for aggregating multiple dataloaders into a single iterable."""
+
+
+class _Aggregator:
+ """
+ Aggregate multiple dataloaders into a unified iterable object.
+
+ The aggregator combines batches produced by multiple dataloaders according
+ to the selected batching strategy. It is primarily used to coordinate the
+ iteration of multiple training conditions within a single training loop.
+ """
+
+ def __init__(self, dataloaders, batching_mode):
+ """
+ Initialization of the :class:`_Aggregator` class.
+
+ :param dict[str, DataLoader] dataloaders: The mapping between condition
+ names and their corresponding dataloaders.
+ :param str batching_mode: The strategy used to aggregate batches across
+ dataloaders. Available options are ``"common_batch_size"`` for
+ uniform batch sizes across conditions, ``"proportional"`` for batch
+ sizes proportional to dataset sizes, and ``"separate_conditions"``
+ for iterating through each condition separately.
+ :raises NotImplementedError: If the selected batching mode is not yet
+ implemented.
+ """
+ # Raise not implemented error for separate_conditions mode
+ if batching_mode == "separate_conditions":
+ raise NotImplementedError(
+ "Batching mode 'separate_conditions' is not implemented yet."
+ )
+
+ # Initialize attributes
+ self.dataloaders = dataloaders
+ self.batching_mode = batching_mode
+
+ def __len__(self):
+ """
+ Return the length of the aggregated dataloader. The length is determined
+ by the number of iterations required to exhaust the dataloaders based on
+ the selected batching mode.
+
+ For ``"separate_conditions"``, the total number of iterations is the sum
+ of the lengths of all dataloaders. For all other batching modes, the
+ length corresponds to the maximum length among the aggregated
+ dataloaders.
+
+ :return: The length of the aggregated dataloader.
+ :rtype: int
+ """
+ # Separate conditions case
+ if self.batching_mode == "separate_conditions":
+ return sum(len(dl) for dl in self.dataloaders.values())
+
+ return max(len(dl) for dl in self.dataloaders.values())
+
+ def __iter__(self):
+ """
+ Iterate over the aggregated dataloaders.
+
+ At each iteration, a dictionary containing one batch per dataloader is
+ yielded. If a dataloader is exhausted before the others, its iterator is
+ restarted automatically to ensure continuous batch generation.
+
+ :yield: The dictionary mapping each condition name to its batch.
+ :rtype: Iterator[dict[str, Any]]
+ """
+ # Initialize iterators for each dataloader
+ iterators = {name: iter(dl) for name, dl in self.dataloaders.items()}
+
+ # Iterate until the maximum number of iterations is reached
+ for _ in range(len(self)):
+ batch = {}
+
+ # Generate a batch for each dataloader
+ for name, dataloader in self.dataloaders.items():
+
+ # Attempt to get the next batch from the dataloader's iterator
+ try:
+ batch[name] = next(iterators[name])
+
+ # Restart the iterator if it is exhausted
+ except StopIteration:
+ iterators[name] = iter(dataloader)
+ batch[name] = next(iterators[name])
+
+ yield batch
diff --git a/pina/_src/data/condition_subset.py b/pina/_src/data/condition_subset.py
new file mode 100644
index 000000000..068e833a2
--- /dev/null
+++ b/pina/_src/data/condition_subset.py
@@ -0,0 +1,101 @@
+"""Utilities for handling condition dataset subsets."""
+
+from torch_geometric.data import Batch
+from pina._src.core.graph import LabelBatch, Graph
+
+
+class _ConditionSubset:
+ """
+ Wrapper around a condition dataset restricted to a subset of indices.
+
+ The class behaves similarly to :class:`torch.utils.data.Subset` and supports
+ cyclic indexing together with optional automatic batching.
+ """
+
+ def __init__(self, condition, indices, automatic_batching):
+ """
+ Initialization of the :class:`_ConditionSubset` class.
+
+ :param BaseCondition condition: The underlying condition.
+ :param list[int] indices: The list of indices identifying the subset
+ samples.
+ :param bool automatic_batching: Whether dataset items should be returned
+ directly or as raw indices.
+ """
+ super().__init__()
+
+ # Initialize the class attributes
+ self.condition = condition
+ self.indices = indices
+ self.automatic_batching = automatic_batching
+
+ # Actual number of samples contained in the subset
+ self.dataset_length = len(self.indices)
+
+ # Effective iterable length used and modified during batching
+ self.iterable_length = self.dataset_length
+
+ def __len__(self):
+ """
+ Return the effective iterable length of the subset.
+
+ :return: The number of accessible elements in the subset.
+ :rtype: int
+ """
+ return self.iterable_length
+
+ def __getitem__(self, idx):
+ """
+ Retrieve an element from the subset.
+
+ If the requested index exceeds the actual dataset size, cyclic indexing
+ is applied through modulo wrapping. When automatic batching is disabled,
+ the raw dataset index is returned instead of the corresponding sample.
+
+ :param int idx: The position of the element inside the subset.
+ :return: The dataset sample or raw dataset index depending on the
+ batching configuration.
+ :rtype: dict | int
+ """
+ # Apply cyclic indexing if the requested index exceeds the subset length
+ if idx >= self.dataset_length:
+ idx = idx % self.dataset_length
+
+ # Fetch the corresponding dataset index from the list of indices
+ idx = self.indices[idx]
+
+ # Return the raw dataset index if automatic batching is disabled
+ if not self.automatic_batching:
+ return idx
+
+ return self.condition[idx]
+
+ def get_all_data(self):
+ """
+ Retrieve and aggregate all subset samples.
+
+ If the returned data contains a ``"data"`` field composed of graph
+ objects, the samples are merged into a single batched graph structure
+ using the appropriate batching implementation.
+
+ :return: The aggregated subset data.
+ :rtype: dict
+ """
+ # Fetch the data corresponding to the subset indices
+ data = self.condition[self.indices]
+
+ # Data as a list of graph objects merged into a single batched graph
+ if "data" in data and isinstance(data["data"], list):
+
+ # Define the batching function
+ batch_fn = (
+ LabelBatch.from_data_list
+ if isinstance(data["data"][0], Graph)
+ else Batch.from_data_list
+ )
+
+ # Merge the list of graph objects into a single batched graph
+ data["data"] = batch_fn(data["data"])
+ data = {"input": data["data"], "target": data["data"].y}
+
+ return data
diff --git a/pina/_src/data/creator.py b/pina/_src/data/creator.py
new file mode 100644
index 000000000..4a5e3207b
--- /dev/null
+++ b/pina/_src/data/creator.py
@@ -0,0 +1,240 @@
+"""Module for creating dataloaders for multiple conditions."""
+
+import torch
+from torch.utils.data.distributed import DistributedSampler
+
+
+class _Creator:
+ """
+ Utility class for creating data loaders associated with multiple conditions.
+
+ The class supports different batching strategies to adapt data loading
+ behavior to specific training requirements
+ """
+
+ def __init__(
+ self,
+ batching_mode,
+ batch_size,
+ shuffle,
+ automatic_batching,
+ num_workers,
+ pin_memory,
+ conditions,
+ ):
+ """
+ Initialization of the :class:`_Creator` class.
+
+ :param str batching_mode: The strategy used to aggregate batches across
+ data loaders. Available options are ``"common_batch_size"`` for
+ uniform batch sizes across conditions, ``"proportional"`` for batch
+ sizes proportional to dataset sizes, and ``"separate_conditions"``
+ for iterating through each condition separately.
+ :param int batch_size: Batch size configuration used by the selected
+ batching strategy. For ``"common_batch_size"``, the same batch size
+ is assigned to all conditions. For ``"proportional"``, this value
+ represents the total batch size distributed proportionally across
+ conditions. For ``"separate_conditions"``, this value is applied
+ independently to each condition and capped by the corresponding
+ dataset size.
+ :param bool shuffle: Whether samples should be shuffled during loading.
+ :param bool automatic_batching: Whether automatic batching should be
+ enabled in the data loaders.
+ :param int num_workers: The number of worker processes used for data
+ loading.
+ :param bool pin_memory: Whether data loaders should pin memory.
+ :param dict[str, BaseCondition] conditions: The mapping between
+ condition names and condition objects responsible for data loader
+ creation.
+ """
+ # Initialize attributes
+ self.batching_mode = batching_mode
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.automatic_batching = automatic_batching
+ self.num_workers = num_workers
+ self.pin_memory = pin_memory
+ self.conditions = conditions
+
+ def __call__(self, datasets):
+ """
+ Create data loaders for all provided datasets.
+
+ Batch sizes are computed according to the selected batching mode, and a
+ dedicated data loader is created for each condition.
+
+ :param dict[str, _ConditionSubset] datasets: The mapping between
+ condition names and datasets.
+ :return: The mapping between condition names and the corresponding
+ data loaders.
+ :rtype: dict[str, DataLoader]
+ """
+ # Compute batch sizes per condition based on batching_mode
+ batch_sizes = self._compute_batch_sizes(datasets)
+ dataloaders = {}
+
+ # If common_batch_size mode, ensure all datasets have the same length
+ if self.batching_mode == "common_batch_size":
+ iterable_length = max(len(dataset) for dataset in datasets.values())
+
+ # Iterate through datasets and create dataloaders
+ for name, dataset in datasets.items():
+
+ # If common_batch_size mode, set max_len for datasets
+ if (
+ self.batching_mode == "common_batch_size"
+ and dataset.dataset_length != batch_sizes[name]
+ ):
+ dataset.iterable_length = iterable_length
+
+ # Create dataloader for the current condition
+ dataloaders[name] = self.conditions[name].create_dataloader(
+ dataset=dataset,
+ batch_size=batch_sizes[name],
+ automatic_batching=self.automatic_batching,
+ sampler=self._define_sampler(dataset, self.shuffle),
+ num_workers=self.num_workers,
+ pin_memory=self.pin_memory,
+ )
+
+ return dataloaders
+
+ def _define_sampler(self, dataset, shuffle):
+ """
+ Define the sampling strategy for a dataset.
+
+ Distributed training uses :class:`DistributedSampler`, while
+ non-distributed execution uses either :class:`RandomSampler` or
+ :class:`SequentialSampler` depending on ``shuffle``.
+
+ :param _ConditionSubset dataset: The dataset associated with the
+ sampler.
+ :param bool shuffle: Whether samples should be shuffled during loading.
+ :return: The configured sampler instance.
+ :rtype: Sampler
+ """
+ # Distributed training case
+ if torch.distributed.is_initialized():
+ return DistributedSampler(dataset, shuffle=shuffle)
+
+ # Non-distributed training case - shuffle True
+ if shuffle:
+ return torch.utils.data.RandomSampler(dataset)
+
+ # Non-distributed training case - shuffle False
+ return torch.utils.data.SequentialSampler(dataset)
+
+ def _compute_batch_sizes(self, datasets):
+ """
+ Compute batch sizes for each dataset according to the selected batching
+ mode.
+
+ :param dict[str, _ConditionSubset] datasets: The mapping between
+ condition names and datasets.
+ :return: The mapping between condition names and computed batch sizes.
+ :rtype: dict[str, int]
+ """
+ # Common batch size mode
+ if self.batching_mode == "common_batch_size":
+
+ # Compute batch size
+ batch_size = (
+ max(dataset.dataset_length for dataset in datasets.values())
+ if self.batch_size is None
+ else self.batch_size
+ )
+
+ return {
+ name: min(batch_size, len(dataset))
+ for name, dataset in datasets.items()
+ }
+
+ # Proportional batch size mode
+ if self.batching_mode == "proportional":
+ return self._compute_proportional_batch_sizes(datasets)
+
+ # Separate conditions mode
+ return {
+ name: (
+ len(dataset)
+ if self.batch_size is None
+ else min(self.batch_size, len(dataset))
+ )
+ for name, dataset in datasets.items()
+ }
+
+ def _compute_proportional_batch_sizes(self, datasets):
+ """
+ Compute batch sizes proportionally to dataset sizes.
+
+ Each dataset receives a fraction of the total batch size proportional to
+ its number of samples, while ensuring that each dataset contributes at
+ least one sample.
+
+ :param dict[str, _ConditionSubset] datasets: The mapping between
+ condition names and datasets.
+ :return: The mapping between condition names and proportional batch
+ sizes.
+ :rtype: dict[str, int]
+ """
+ # Compute the sizes of each dataset
+ dataset_sizes = {
+ name: len(dataset) for name, dataset in datasets.items()
+ }
+
+ # Determine the total number of elements across all datasets
+ total_size = sum(dataset_sizes.values())
+
+ # Compute the batch sizes
+ batch_sizes = {
+ name: max(1, int(self.batch_size * (size / total_size)))
+ for name, size in dataset_sizes.items()
+ }
+
+ # Compute assigned batch size and difference with the total batch size
+ assigned_batch_size = sum(batch_sizes.values())
+ difference = self.batch_size - assigned_batch_size
+
+ # If difference > 0, distribute to datasets with more than 1 sample
+ if difference > 0:
+
+ # Sort datasets by size in descending order
+ sorted_datasets = sorted(
+ dataset_sizes,
+ key=lambda name: dataset_sizes[name],
+ reverse=True,
+ )
+
+ # Distribute to datasets with more than 1 sample
+ for name in sorted_datasets:
+
+ # Stop distribution when the difference is fully allocated
+ if difference == 0:
+ break
+
+ # Distribute to datasets with more than 1 sample
+ if dataset_sizes[name] > 1:
+ batch_sizes[name] += 1
+ difference -= 1
+
+ # If difference < 0, reduce from datasets with more than 1 sample
+ if difference < 0:
+
+ # Sort batches by size in descending order
+ sorted_batches = sorted(
+ batch_sizes, key=lambda name: batch_sizes[name], reverse=True
+ )
+
+ # Reduce from datasets with more than 1 sample
+ for name in sorted_batches:
+
+ # Stop reduction when the difference is fully allocated
+ if difference == 0:
+ break
+
+ # Reduce from datasets with more than 1 sample
+ if batch_sizes[name] > 1:
+ batch_sizes[name] -= 1
+ difference += 1
+
+ return batch_sizes
diff --git a/pina/_src/data/data_module.py b/pina/_src/data/data_module.py
new file mode 100644
index 000000000..c5d3804a5
--- /dev/null
+++ b/pina/_src/data/data_module.py
@@ -0,0 +1,278 @@
+"""
+Utilities for creating and managing datasets and dataloaders.
+
+This module defines a custom extension of the Lighting DataModule used to handle
+dataset splitting, batching, and dataloader creation for PINA conditions.
+"""
+
+import warnings
+import torch
+from lightning.pytorch import LightningDataModule
+from pina._src.data.condition_subset import _ConditionSubset
+from pina._src.data.aggregator import _Aggregator
+from pina._src.data.creator import _Creator
+
+
+class DataModule(LightningDataModule):
+ """
+ An extension of the Lightning data module for managing PINA condition
+ datasets.
+
+ The data module handles train/validation/test dataset splitting, condition
+ subset creation, dataloader construction, and batching coordination across
+ multiple conditions.
+
+ Dataset splitting is performed independently for each condition, and the
+ resulting subsets are wrapped into :class:`_ConditionSubset` objects.
+ Dataloaders are then created and aggregated according to the selected
+ batching strategy.
+ """
+
+ def __init__(
+ self,
+ problem,
+ train_size,
+ val_size,
+ test_size,
+ batch_size,
+ batching_mode,
+ automatic_batching,
+ shuffle,
+ num_workers,
+ pin_memory,
+ ):
+ """
+ Initialization of the :class:`DataModule` class.
+
+ :param BaseProblem problem: The problem containing the conditions and
+ sampled data used to construct datasets and dataloaders.
+ :param float train_size: The fraction of samples assigned to the
+ training split. Must belong to the interval ``[0, 1]``.
+ :param float val_size: The fraction of samples assigned to the
+ validation split. Must belong to the interval ``[0, 1]``.
+ :param float test_size: The fraction of samples assigned to the test
+ split. Must belong to the interval ``[0, 1]``.
+ :param int batch_size: The number of samples per batch. If ``None``, the
+ entire dataset is processed as a single batch.
+ :param str batching_mode: The strategy used to aggregate batches across
+ dataloaders. Available options are ``"common_batch_size"`` for
+ uniform batch sizes across conditions, ``"proportional"`` for batch
+ sizes proportional to dataset sizes, and ``"separate_conditions"``
+ for iterating through each condition separately.
+ :param bool automatic_batching: Whether PyTorch automatic batching
+ should be enabled. If ``True``, dataset elements are retrieved
+ individually and collated into batches by the dataloader.
+ If ``False``, entire subsets are retrieved directly from the
+ condition object.
+ :param bool shuffle: Whether condition samples should be shuffled before
+ splitting.
+ :param int num_workers: The number of worker processes used by
+ dataloaders.
+ :param bool pin_memory: Whether pinned memory should be enabled during
+ data loading.
+ :raises UserWarning: If ``num_workers`` is set to non-default value
+ while ``batch_size`` is None.
+ :raises UserWarning: If ``pin_memory`` is set to ``True`` while
+ ``batch_size`` is None.
+ """
+ super().__init__()
+
+ # Initialize the attributes -- consistency checked in trainer
+ self.problem = problem
+ self.batch_size = batch_size
+ self.batching_mode = batching_mode
+ self.automatic_batching = automatic_batching
+ self.shuffle = shuffle
+ self.num_workers = num_workers
+ self.pin_memory = pin_memory
+
+ # If batch size is None, num_workers has no effect
+ if batch_size is None and num_workers != 0:
+ warnings.warn("num_workers has no effect when batch_size is None.")
+ self.num_workers = 0
+
+ # If batch size is None, pin_memory has no effect
+ if batch_size is None and pin_memory:
+ warnings.warn("pin_memory has no effect when batch_size is None.")
+ self.pin_memory = False
+
+ # Move domain discretisation into conditions subsets
+ self.problem.move_discretisation_into_conditions()
+
+ # If no splits are defined, use the default dataloaders
+ if train_size == 0:
+ self.train_dataloader = super().train_dataloader
+ if val_size == 0:
+ self.val_dataloader = super().val_dataloader
+ if test_size == 0:
+ self.test_dataloader = super().test_dataloader
+
+ # Otherwise, create the condition splits and initialize the creator
+ self._create_condition_splits(train_size, test_size)
+ self.creator = _Creator(
+ batching_mode=self.batching_mode,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ automatic_batching=self.automatic_batching,
+ num_workers=self.num_workers,
+ pin_memory=self.pin_memory,
+ conditions=self.problem.conditions,
+ )
+
+ def _create_condition_splits(self, train_size, test_size):
+ """
+ Create train/validation/test index splits for each condition.
+
+ Samples belonging to each condition are optionally shuffled before being
+ partitioned into train, validation, and test subsets according to the
+ specified split fractions.
+
+ :param float train_size: The fraction of samples assigned to the
+ training split. Must belong to the interval ``[0, 1]``.
+ :param float test_size: The fraction of samples assigned to the test
+ split. Must belong to the interval ``[0, 1]``.
+ """
+ # Initialize the dictionary to store the split idx for each condition
+ self.split_idxs = {}
+
+ # Iterate through conditions and create the splits
+ for condition_name, condition in self.problem.conditions.items():
+
+ # Get the total number of samples for the current condition
+ condition_length = len(condition)
+
+ # Generate shuffled or sequential indices for the condition samples
+ indices = (
+ torch.randperm(condition_length).tolist()
+ if self.shuffle
+ else list(range(condition_length))
+ )
+
+ # Compute the split indices for train, validation, and test subsets
+ train_end = int(train_size * condition_length)
+ test_end = train_end + int(test_size * condition_length)
+
+ # Store the computed split indices in the dictionary
+ self.split_idxs[condition_name] = {
+ "train": indices[:train_end],
+ "test": indices[train_end:test_end],
+ "val": indices[test_end:],
+ }
+
+ def setup(self, stage=None):
+ """
+ Create dataset subsets for the requested execution stage.
+
+ Depending on the selected stage, it initializes the ``train_datasets``,
+ the ``val_datasets``, or the ``test_datasets`` attributes. Each dataset
+ is represented as a mapping between condition names and
+ :class:`_ConditionSubset` instances.
+
+ :param str stage: The execution stage. Available options are ``"fit"``
+ for training/validation and ``"test"`` for testing. If ``None``, both
+ training/validation and testing datasets are created.
+ Default is ``None``.
+ :raises ValueError: If the provided stage is invalid.
+ """
+ # Validate the stage argument
+ if stage not in ("fit", "test", None):
+ raise ValueError(
+ f"Invalid stage. Got {stage}, expected either 'fit' or 'test'."
+ )
+
+ # Fit stage: create training and validation datasets
+ if stage in ("fit", None):
+
+ # Train dataset
+ self.train_datasets = {
+ name: _ConditionSubset(
+ condition,
+ self.split_idxs[name]["train"],
+ automatic_batching=self.automatic_batching,
+ )
+ for name, condition in self.problem.conditions.items()
+ if len(self.split_idxs[name]["train"]) > 0
+ }
+
+ # Validation dataset
+ self.val_datasets = {
+ name: _ConditionSubset(
+ condition,
+ self.split_idxs[name]["val"],
+ automatic_batching=self.automatic_batching,
+ )
+ for name, condition in self.problem.conditions.items()
+ if len(self.split_idxs[name]["val"]) > 0
+ }
+
+ # Test stage: create testing dataset
+ if stage in ("test", None):
+
+ # Test dataset
+ self.test_datasets = {
+ name: _ConditionSubset(
+ condition,
+ self.split_idxs[name]["test"],
+ automatic_batching=self.automatic_batching,
+ )
+ for name, condition in self.problem.conditions.items()
+ if len(self.split_idxs[name]["test"]) > 0
+ }
+
+ def transfer_batch_to_device(self, batch, device, _):
+ """
+ Transfer a batch to the target device.
+
+ The method transfers all condition batches contained in the aggregated
+ batch dictionary to the specified device.
+
+ :param dict batch: The mapping between the condition names and the
+ condition batches.
+ :param torch.device device: The target device.
+ :param _: Placeholder argument, not used.
+ :return: A list of tuples containing condition names and transferred
+ batches.
+ :rtype: list[tuple[str, Any]]
+ """
+ return [
+ (condition_name, condition.to(device))
+ for condition_name, condition in batch.items()
+ ]
+
+ def train_dataloader(self):
+ """
+ Create the aggregated train dataloader.
+
+ :return: The aggregated dataloader coordinating all train condition
+ dataloaders.
+ :rtype: _Aggregator
+ """
+ return _Aggregator(
+ self.creator(self.train_datasets),
+ batching_mode=self.batching_mode,
+ )
+
+ def val_dataloader(self):
+ """
+ Create the aggregated validation dataloader.
+
+ :return: The aggregated dataloader coordinating all validation condition
+ dataloaders.
+ :rtype: _Aggregator
+ """
+ return _Aggregator(
+ self.creator(self.val_datasets), batching_mode=self.batching_mode
+ )
+
+ def test_dataloader(self):
+ """
+ Create the aggregated test dataloader.
+
+ :return: The aggregated dataloader coordinating all test condition
+ dataloaders.
+ :rtype: _Aggregator
+ """
+ return _Aggregator(
+ self.creator(self.test_datasets),
+ batching_mode=self.batching_mode,
+ )
diff --git a/pina/_src/data/manager/batch_manager.py b/pina/_src/data/manager/batch_manager.py
new file mode 100644
index 000000000..cdea44616
--- /dev/null
+++ b/pina/_src/data/manager/batch_manager.py
@@ -0,0 +1,47 @@
+"""Module for the Batch Manager class."""
+
+
+class _BatchManager(dict):
+ """
+ Dict-like container for batched data with attribute-style access and
+ convenience methods for device placement.
+ """
+
+ def to(self, device):
+ """
+ Move all compatible values in the batch to the specified device.
+
+ :param device: The target device.
+ :type device: torch.device | str
+ :return: The updated batch manager.
+ :rtype: _BatchManager
+ """
+ for key, value in self.items():
+ if hasattr(value, "to"):
+ moved_value = value.to(device)
+ self[key] = moved_value
+
+ return self
+
+ def __getattribute__(self, name):
+ """
+ Provide attribute-style access to dictionary keys.
+
+ :param str name: The name of the attribute to retrieve.
+ :raises AttributeError: If the attribute is not found as a standard
+ attribute or a dictionary key.
+ :return: The value associated with the attribute name.
+ :rtype: Any
+ """
+ # First, attempt to retrieve the attribute using the standard method.
+ try:
+ return super().__getattribute__(name)
+
+ # If not found, attempt to retrieve the attribute as a dictionary key.
+ except AttributeError:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(
+ f"'BatchManager' object has no attribute '{name}'"
+ )
diff --git a/pina/_src/data/manager/data_manager.py b/pina/_src/data/manager/data_manager.py
new file mode 100644
index 000000000..3fd976d1d
--- /dev/null
+++ b/pina/_src/data/manager/data_manager.py
@@ -0,0 +1,50 @@
+"""Module for the Data Manager factory class."""
+
+import torch
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.equation.base_equation import BaseEquation
+from pina._src.data.manager.graph_data_manager import _GraphDataManager
+from pina._src.data.manager.tensor_data_manager import _TensorDataManager
+
+
+class _DataManager:
+ """
+ Factory class for data manager implementations.
+
+ This class dispatches object creation to either
+ :class:`~pina.data.manager.tensor_data_manager._TensorDataManager` or
+ :class:`~pina.data.manager.graph_data_manager._GraphDataManager` depending
+ on the types of the provided keyword arguments.
+ """
+
+ def __new__(cls, **kwargs):
+ """
+ Create the appropriate data manager implementation based on the provided
+ keyword arguments.
+
+ If all values in ``kwargs`` are instances of :class:`torch.Tensor`,
+ :class:`~pina.label_tensor.LabelTensor`, or
+ :class:`~pina.equation.base_equation.BaseEquation`, an instance of
+ :class:`~pina.data.manager.tensor_data_manager._TensorDataManager` is
+ created. Otherwise, an instance of
+ :class:`~pina.data.manager.graph_data_manager._GraphDataManager` is
+ created.
+
+ :param dict kwargs: The keyword arguments for the data manager.
+ :return: A concrete data manager instance.
+ :rtype: _TensorDataManager | _GraphDataManager
+ """
+ # Guard subclass instantiation
+ if cls is not _DataManager:
+ return super().__new__(cls)
+
+ # Check if there are only tensors / equations
+ is_tensor_only = all(
+ isinstance(v, (torch.Tensor, LabelTensor, BaseEquation))
+ for v in kwargs.values()
+ )
+
+ # Choose the appropriate subclass
+ subclass = _TensorDataManager if is_tensor_only else _GraphDataManager
+
+ return subclass(**kwargs)
diff --git a/pina/_src/data/manager/data_manager_interface.py b/pina/_src/data/manager/data_manager_interface.py
new file mode 100644
index 000000000..41b841e39
--- /dev/null
+++ b/pina/_src/data/manager/data_manager_interface.py
@@ -0,0 +1,53 @@
+"""Module for the Data Manager interface."""
+
+from abc import ABCMeta, abstractmethod
+
+
+class _DataManagerInterface(metaclass=ABCMeta):
+ """
+ Abstract interface for all data managers.
+ """
+
+ @abstractmethod
+ def __len__(self):
+ """
+ Return the number of samples in the data manager.
+
+ :return: The number of samples.
+ :rtype: int
+ """
+
+ @abstractmethod
+ def __getitem__(self, idx):
+ """
+ Return the item at the specified indices.
+
+ :param idx: The indices of the data point to retrieve.
+ :type idx: int | slice | list[int] | torch.Tensor
+ :return: A new :class:`_DataManager` instance containing the
+ selected data items.
+ :rtype: _DataManager
+ """
+
+ @abstractmethod
+ def to_batch(self):
+ """
+ Create a batch from the current data manager.
+
+ :return: A new :class:`~pina.condition.data_manager._DataManager`
+ instance with batched data.
+ :rtype: _DataManager
+ """
+
+ @staticmethod
+ @abstractmethod
+ def create_batch(items):
+ """
+ Create a batch from a list of :class:`_DataManager` items.
+
+ :param list[_DataManager] items: A list of
+ :class:`_DataManager` items to batch.
+ :return: A new instance of :class:`_DataManager` containing the
+ batched data.
+ :rtype: _DataManager
+ """
diff --git a/pina/_src/data/manager/graph_data_manager.py b/pina/_src/data/manager/graph_data_manager.py
new file mode 100644
index 000000000..660c75f83
--- /dev/null
+++ b/pina/_src/data/manager/graph_data_manager.py
@@ -0,0 +1,246 @@
+"""Module for the Graph-Data Manager class."""
+
+import torch
+from torch_geometric.data import Data
+from torch_geometric.data.batch import Batch
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.graph import Graph, LabelBatch
+from pina._src.data.manager.batch_manager import _BatchManager
+from pina._src.data.manager.data_manager_interface import _DataManagerInterface
+
+
+class _GraphDataManager(_DataManagerInterface):
+ """
+ Data manager for graph-based data. It handles inputs stored as
+ :class:`Graph`, :class:`Data`, or lists / tuples of these types.
+ """
+
+ def __init__(self, **kwargs):
+ """
+ Initialization of the :class:`_GraphDataManager` class.
+
+ :param dict kwargs: The keyword arguments for the graph data manager.
+ """
+ # Initialize keys
+ self.keys = list(kwargs.keys())
+
+ # Find graph-based data
+ self.graph_key = next(
+ k
+ for k, v in kwargs.items()
+ if isinstance(v, (Graph, Data, list, tuple))
+ )
+
+ # Find tensor data
+ self.keys = [
+ k
+ for k in self.keys
+ if k != self.graph_key
+ and isinstance(kwargs[k], (torch.Tensor, LabelTensor))
+ ]
+
+ # Prepare graphs and assign tensors
+ self.data = self._prepare_graphs(kwargs)
+
+ def __len__(self):
+ """
+ Return the number of samples in the graph data manager.
+
+ :return: The number of samples.
+ :rtype: int
+ """
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ """
+ Return the item at the specified indices.
+
+ :param idx: The indices of the graphs to retrieve.
+ :type idx: int | slice | list[int] | torch.Tensor
+ :raises TypeError: If an index with invalid type is passed.
+ :return: A new :class:`_GraphDataManager` instance containing the
+ selected graphs.
+ :rtype: _GraphDataManager
+ """
+ # Selection for integers or slices
+ if isinstance(idx, (int, slice)):
+ selected = self.data[idx]
+
+ # Selection for lists or tensors
+ elif isinstance(idx, (list, torch.Tensor)):
+ selected = [self.data[i] for i in idx]
+
+ # Raise TypeError if index type is invalid
+ else:
+ raise TypeError(f"Invalid index type: {type(idx)}")
+
+ # Ensure selected is a list
+ if not isinstance(selected, list):
+ selected = [selected]
+
+ return _GraphDataManager._init_from_graphs_list(
+ selected, graph_key=self.graph_key, keys=self.keys
+ )
+
+ def __getattr__(self, name):
+ """
+ Provide dynamic access to stored graph and tensor data.
+
+ If ``name`` corresponds to the graph key, return the list of graph
+ objects. If it matches a tensor key, retrieve the corresponding
+ tensors from all graphs and stack them along the batch dimension.
+
+ :param str name: The name of the attribute to access.
+ :return: The requested graph data or stacked tensor values.
+ :rtype: torch.Tensor | LabelTensor | list[Graph] | list[Data]
+ """
+ # Stack tensors from all graph if name is a tensor key
+ if name in self.keys:
+ tensors = [getattr(g, name) for g in self.data]
+ batch_fn = (
+ LabelTensor.stack
+ if isinstance(tensors[0], LabelTensor)
+ else torch.stack
+ )
+ return batch_fn(tensors)
+
+ # Otherwise, return graphs
+ if name == self.graph_key:
+ return self.data if len(self.data) > 1 else self.data[0]
+
+ return super().__getattribute__(name)
+
+ def _prepare_graphs(self, kwargs):
+ """
+ Attach tensor data to the corresponding graph objects.
+
+ :param kwargs: The keyword arguments containing graph data and
+ associated tensor features.
+ :raises ValueError: If the number of graphs does not match the number of
+ samples in the tensor of features to associate.
+ :return: A list of graphs with the corresponding tensors assigned.
+ :rtype: list[Graph] | list[Data]
+ """
+ # Get graph-based data and store in a list
+ graphs = kwargs.pop(self.graph_key)
+ if not isinstance(graphs, (list, tuple)):
+ graphs = [graphs]
+
+ # Iterate of items
+ for name, tensor in kwargs.items():
+
+ # Verify the consistency between the number of graphs and samples
+ if len(graphs) != tensor.shape[0]:
+ raise ValueError(
+ f"Number of graphs ({len(graphs)}) does not match "
+ f"number of samples for key '{name}' "
+ f"({kwargs[name].shape[0]})."
+ )
+
+ # Assign tensors to graphs
+ for i, g in enumerate(graphs):
+ setattr(g, name, tensor[i])
+
+ return graphs
+
+ def to_batch(self):
+ """
+ Create a batch from the current graph data manager.
+
+ :return: A new instance of :class:`_BatchManager` with batched data.
+ :rtype: _BatchManager
+ """
+ # Define the batch function
+ batching_fn = (
+ LabelBatch.from_data_list
+ if isinstance(self.data[0], Graph)
+ else Batch.from_data_list
+ )
+
+ # Create the batch manager
+ batch_data = _BatchManager()
+ batched_graph = batching_fn(self.data)
+ for k in self.keys:
+ if k == self.graph_key:
+ continue
+ batch_data[k] = getattr(batched_graph, k)
+ delattr(batched_graph, k)
+ batch_data[self.graph_key] = batched_graph
+
+ return batch_data
+
+ @staticmethod
+ def create_batch(items):
+ """
+ Create a batch from a list of :class:`_GraphDataManager` items.
+
+ :param list[_GraphDataManager] items: A list of
+ :class:`_GraphDataManager` items to batch.
+ :return: A new instance of :class:`_BatchManager` containing the batched
+ data.
+ :rtype: _BatchManager
+ """
+ # Return None if no items are provided
+ if not items:
+ return None
+
+ # Retrieve the first _GraphDataManager of the list and corresponding key
+ first = items[0]
+ graph_key = first.graph_key
+
+ # Initialize the batch manager
+ batch_data = _BatchManager()
+
+ # Define batch function
+ batching_fn = (
+ LabelBatch.from_data_list
+ if isinstance(first.data[0], Graph)
+ else Batch.from_data_list
+ )
+
+ # Batch over graphs
+ batched_graph = batching_fn([item.data[0] for item in items])
+
+ # Use a set for O(1) lookups if keys are large
+ keys_to_transfer = set(first.keys)
+ if graph_key in keys_to_transfer:
+ keys_to_transfer.remove(graph_key)
+
+ # Iterate over the keys of the _GraphDataManager
+ for k in keys_to_transfer:
+
+ # Extract values
+ val = getattr(batched_graph, k, None)
+ if val is not None:
+ batch_data[k] = val
+ delattr(batched_graph, k)
+
+ # Assign key to batch
+ batch_data[graph_key] = batched_graph
+
+ return batch_data
+
+ @classmethod
+ def _init_from_graphs_list(cls, graphs, graph_key, keys):
+ """
+ Create a :class:`_GraphDataManager` instance directly from a list of
+ graph objects.
+
+ This method bypasses the standard initialization logic and is used
+ internally to construct new instances (e.g., subsets) from already
+ processed graph data.
+
+ :param list graphs: A list of graph objects.
+ :param str graph_key: The name of the attribute used to store the
+ graphs.
+ :param list keys: A list of tensor keys associated with the graphs.
+ :return: A new instance of :class:`_GraphDataManager`.
+ :rtype: _GraphDataManager
+ """
+ # Create a new instance without calling __init__
+ obj = _GraphDataManager.__new__(_GraphDataManager)
+ obj.graph_key = graph_key
+ obj.keys = keys
+ obj.data = graphs
+
+ return obj
diff --git a/pina/_src/data/manager/tensor_data_manager.py b/pina/_src/data/manager/tensor_data_manager.py
new file mode 100644
index 000000000..2e530c40f
--- /dev/null
+++ b/pina/_src/data/manager/tensor_data_manager.py
@@ -0,0 +1,110 @@
+"""Module for the Tensor-Data Manager class."""
+
+import torch
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.data.manager.batch_manager import _BatchManager
+from pina._src.data.manager.data_manager_interface import _DataManagerInterface
+
+
+class _TensorDataManager(_DataManagerInterface):
+ """
+ Data manager for tensor-based data. It handles inputs stored as
+ :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`.
+ """
+
+ def __init__(self, **kwargs):
+ """
+ Initialization of the :class:`_TensorDataManager` class.
+
+ :param dict kwargs: The keyword arguments for the tensor data manager.
+ """
+ self.keys = list(kwargs.keys())
+ self.data = kwargs
+
+ # Set attributes from kwargs
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ def __len__(self):
+ """
+ Return the number of samples in the tensor data manager.
+
+ :return: The number of samples.
+ :rtype: int
+ """
+ return self.data[self.keys[0]].shape[0]
+
+ def __getitem__(self, idx):
+ """
+ Return the item at the specified indices.
+
+ :param idx: The indices of the data point to retrieve.
+ :type idx: int | slice | list[int] | torch.Tensor
+ :return: A new :class:`_TensorDataManager` instance containing the
+ selected data items.
+ :rtype: _TensorDataManager
+ """
+ # Get data at selected indices
+ new_data = {
+ k: (self.data[k][idx] if k in self.keys else self.data[k])
+ for k in self.keys
+ }
+
+ return _TensorDataManager(**new_data)
+
+ def to_batch(self):
+ """
+ Create a batch from the current tensor data manager.
+
+ :return: A new instance of :class:`_BatchManager` with batched data.
+ :rtype: _BatchManager
+ """
+ # Create the batch manager
+ batch_data = _BatchManager()
+ for k in self.keys:
+ batch_data[k] = self.data[k]
+
+ return batch_data
+
+ @staticmethod
+ def create_batch(items):
+ """
+ Create a batch from a list of :class:`_TensorDataManager` items.
+
+ :param list[_TensorDataManager] items: A list of
+ :class:`_TensorDataManager` items to batch.
+ :return: A new instance of :class:`_BatchManager` containing the batched
+ data.
+ :rtype: _BatchManager
+ """
+ # Return None if no items are provided
+ if not items:
+ return None
+
+ # Retrieve the first _TensorDataManager of the list
+ first = items[0]
+
+ # Initialize the batch manager
+ batch_data = _BatchManager()
+
+ # Iterate over the keys of the _TensorDataManager
+ for k in first.keys:
+
+ # Extract values and a sample used to determine the batch function
+ vals = [it.data[k] for it in items]
+ sample = vals[0]
+
+ # Define the batch function based on the data type
+ if isinstance(sample, (torch.Tensor, LabelTensor)):
+ batch_fn = (
+ LabelTensor.stack
+ if isinstance(sample, LabelTensor)
+ else torch.stack
+ )
+ batch_data[k] = batch_fn(vals)
+
+ # If no tensor is provided, just take the first value
+ else:
+ batch_data[k] = sample
+
+ return batch_data
diff --git a/pina/_src/data/single_batch_data_loader.py b/pina/_src/data/single_batch_data_loader.py
new file mode 100644
index 000000000..bec4cf93e
--- /dev/null
+++ b/pina/_src/data/single_batch_data_loader.py
@@ -0,0 +1,106 @@
+"""Module for the Single-Batch Data Loader class."""
+
+import torch
+
+
+class _SingleBatchDataLoader:
+ """
+ Data loader wrapper that returns the entire dataset as a single batch.
+
+ This utility is intended for cases where mini-batching is disabled (e.g.
+ ``batch_size=None``). The loader yields exactly one batch per iteration.
+
+ In distributed environments, the dataset is automatically partitioned across
+ processes according to the current rank and world size. Each process
+ receives only its corresponding subset of data.
+
+ In non-distributed environments, the full dataset is returned.
+ """
+
+ def __init__(self, dataset):
+ """
+ Initialization of the :class:`_SingleBatchDataLoader` class.
+
+ In distributed training, the dataset indices are split across processes
+ using the current rank and world size, so that each process receives
+ only its corresponding subset of data.
+
+ In non-distributed training, the full dataset is loaded.
+
+ The resulting data is converted into a single batch and stored
+ internally.
+
+ :param dataset: Dataset object.
+ :raises RuntimeError: If the dataset size is smaller than the number of
+ distributed processes.
+ """
+ # Initialize the flag to track if the batch has been yielded
+ self._has_yielded = False
+
+ # Distributed training
+ if (
+ torch.distributed.is_available()
+ and torch.distributed.is_initialized()
+ ):
+ # Get rank and world_size
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+
+ # Raise runtime error if the dataset is smaller than the world size
+ if len(dataset) < world_size:
+ raise RuntimeError(
+ "Dataset size is smaller than the distributed world size. "
+ "Increase the dataset size or use a single GPU."
+ )
+
+ # Select dataset idx assigned to the current distributed process
+ idx, i = [], rank
+ while i < len(dataset):
+ idx.append(i)
+ i += world_size
+
+ # Fetch the process-specific subset
+ self.dataset = dataset.fetch_from_idx_list(idx).to_batch()
+
+ # Non-distributed training
+ else:
+ self.dataset = dataset.get_all_data().to_batch()
+
+ def __iter__(self):
+ """
+ Return the data loader iterator.
+
+ :return: The data loader instance itself.
+ :rtype: _SingleBatchDataLoader
+ """
+ # Reset the flag to yield the batch again if iterator is restarted
+ self._has_yielded = False
+ return self
+
+ def __len__(self):
+ """
+ Return the number of batches produced by the data loader.
+
+ Since the entire dataset is returned as a single batch, the length is
+ always ``1``.
+
+ :return: The number of batches.
+ :rtype: int
+ """
+ return 1
+
+ def __next__(self):
+ """
+ Return the next batch.
+
+ :return: The dataset converted into a single batch.
+ :rtype: _BatchManager
+ """
+ # Yield the batch only once per iteration
+ if self._has_yielded:
+ raise StopIteration
+
+ # Set the flag to indicate that the batch has been yielded
+ self._has_yielded = True
+
+ return self.dataset
diff --git a/pina/_src/domain/__init__.py b/pina/_src/domain/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/domain/base_domain.py b/pina/_src/domain/base_domain.py
similarity index 94%
rename from pina/domain/base_domain.py
rename to pina/_src/domain/base_domain.py
index c7bef9700..d3cea3848 100644
--- a/pina/domain/base_domain.py
+++ b/pina/_src/domain/base_domain.py
@@ -2,8 +2,8 @@
from copy import deepcopy
from abc import ABCMeta
-from .domain_interface import DomainInterface
-from ..utils import check_consistency, check_positive_integer
+from pina._src.domain.domain_interface import DomainInterface
+from pina._src.core.utils import check_consistency, check_positive_integer
class BaseDomain(DomainInterface, metaclass=ABCMeta):
@@ -103,8 +103,17 @@ def update(self, domain):
f"with domain of type {type(domain)}."
)
- # Update fixed and ranged variables
+ # Create a deepcopy of the current domain
updated = deepcopy(self)
+
+ # Remove keys that change category
+ for key in domain.fixed:
+ updated.range.pop(key, None)
+
+ for key in domain.range:
+ updated.fixed.pop(key, None)
+
+ # Update fixed and ranged variables
updated.fixed.update(domain.fixed)
updated.range.update(domain.range)
diff --git a/pina/domain/base_operation.py b/pina/_src/domain/base_operation.py
similarity index 97%
rename from pina/domain/base_operation.py
rename to pina/_src/domain/base_operation.py
index 8261ae431..ff83e1551 100644
--- a/pina/domain/base_operation.py
+++ b/pina/_src/domain/base_operation.py
@@ -2,9 +2,9 @@
from copy import deepcopy
from abc import ABCMeta
-from .operation_interface import OperationInterface
-from .base_domain import BaseDomain
-from ..utils import check_consistency
+from pina._src.domain.operation_interface import OperationInterface
+from pina._src.domain.base_domain import BaseDomain
+from pina._src.core.utils import check_consistency
class BaseOperation(OperationInterface, BaseDomain, metaclass=ABCMeta):
diff --git a/pina/domain/cartesian_domain.py b/pina/_src/domain/cartesian_domain.py
similarity index 97%
rename from pina/domain/cartesian_domain.py
rename to pina/_src/domain/cartesian_domain.py
index 3333a8fc3..089e3377c 100644
--- a/pina/domain/cartesian_domain.py
+++ b/pina/_src/domain/cartesian_domain.py
@@ -1,10 +1,10 @@
"""Module for the Cartesian Domain."""
import torch
-from .base_domain import BaseDomain
-from .union import Union
-from ..utils import torch_lhs, chebyshev_roots, check_consistency
-from ..label_tensor import LabelTensor
+from pina._src.domain.base_domain import BaseDomain
+from pina._src.domain.union import Union
+from pina._src.core.utils import torch_lhs, chebyshev_roots, check_consistency
+from pina._src.core.label_tensor import LabelTensor
class CartesianDomain(BaseDomain):
diff --git a/pina/domain/difference.py b/pina/_src/domain/difference.py
similarity index 96%
rename from pina/domain/difference.py
rename to pina/_src/domain/difference.py
index 76807b035..ce87920e5 100644
--- a/pina/domain/difference.py
+++ b/pina/_src/domain/difference.py
@@ -1,8 +1,8 @@
"""Module for the Difference operation."""
-from .base_operation import BaseOperation
-from ..label_tensor import LabelTensor
-from ..utils import check_consistency
+from pina._src.domain.base_operation import BaseOperation
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency
class Difference(BaseOperation):
diff --git a/pina/domain/domain_interface.py b/pina/_src/domain/domain_interface.py
similarity index 100%
rename from pina/domain/domain_interface.py
rename to pina/_src/domain/domain_interface.py
diff --git a/pina/domain/ellipsoid_domain.py b/pina/_src/domain/ellipsoid_domain.py
similarity index 98%
rename from pina/domain/ellipsoid_domain.py
rename to pina/_src/domain/ellipsoid_domain.py
index ecb08e37c..402ec29a8 100644
--- a/pina/domain/ellipsoid_domain.py
+++ b/pina/_src/domain/ellipsoid_domain.py
@@ -2,9 +2,9 @@
from copy import deepcopy
import torch
-from .base_domain import BaseDomain
-from ..label_tensor import LabelTensor
-from ..utils import check_consistency
+from pina._src.domain.base_domain import BaseDomain
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency
class EllipsoidDomain(BaseDomain):
diff --git a/pina/domain/exclusion.py b/pina/_src/domain/exclusion.py
similarity index 97%
rename from pina/domain/exclusion.py
rename to pina/_src/domain/exclusion.py
index 59205f3a8..914e17086 100644
--- a/pina/domain/exclusion.py
+++ b/pina/_src/domain/exclusion.py
@@ -1,9 +1,9 @@
"""Module for the Exclusion set-operation."""
import random
-from .base_operation import BaseOperation
-from ..label_tensor import LabelTensor
-from ..utils import check_consistency
+from pina._src.domain.base_operation import BaseOperation
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency
class Exclusion(BaseOperation):
diff --git a/pina/domain/intersection.py b/pina/_src/domain/intersection.py
similarity index 96%
rename from pina/domain/intersection.py
rename to pina/_src/domain/intersection.py
index 105575df1..1b004556e 100644
--- a/pina/domain/intersection.py
+++ b/pina/_src/domain/intersection.py
@@ -1,9 +1,9 @@
"""Module for the Intersection operation."""
import random
-from .base_operation import BaseOperation
-from ..label_tensor import LabelTensor
-from ..utils import check_consistency
+from pina._src.domain.base_operation import BaseOperation
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency
class Intersection(BaseOperation):
diff --git a/pina/domain/operation_interface.py b/pina/_src/domain/operation_interface.py
similarity index 92%
rename from pina/domain/operation_interface.py
rename to pina/_src/domain/operation_interface.py
index 9be458972..357556105 100644
--- a/pina/domain/operation_interface.py
+++ b/pina/_src/domain/operation_interface.py
@@ -1,7 +1,7 @@
"""Module for the Operation Interface."""
from abc import ABCMeta, abstractmethod
-from .domain_interface import DomainInterface
+from pina._src.domain.domain_interface import DomainInterface
class OperationInterface(DomainInterface, metaclass=ABCMeta):
diff --git a/pina/domain/simplex_domain.py b/pina/_src/domain/simplex_domain.py
similarity index 98%
rename from pina/domain/simplex_domain.py
rename to pina/_src/domain/simplex_domain.py
index 9e3a3e58f..5dff002ce 100644
--- a/pina/domain/simplex_domain.py
+++ b/pina/_src/domain/simplex_domain.py
@@ -2,9 +2,9 @@
from copy import deepcopy
import torch
-from .base_domain import BaseDomain
-from ..label_tensor import LabelTensor
-from ..utils import check_consistency
+from pina._src.domain.base_domain import BaseDomain
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency
class SimplexDomain(BaseDomain):
diff --git a/pina/domain/union.py b/pina/_src/domain/union.py
similarity index 95%
rename from pina/domain/union.py
rename to pina/_src/domain/union.py
index df094bb82..eff137df3 100644
--- a/pina/domain/union.py
+++ b/pina/_src/domain/union.py
@@ -1,9 +1,9 @@
"""Module for the Union operation."""
import random
-from .base_operation import BaseOperation
-from ..label_tensor import LabelTensor
-from ..utils import check_consistency
+from pina._src.domain.base_operation import BaseOperation
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency
class Union(BaseOperation):
diff --git a/pina/_src/equation/__init__.py b/pina/_src/equation/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/equation/base_equation.py b/pina/_src/equation/base_equation.py
new file mode 100644
index 000000000..4fff8dd3b
--- /dev/null
+++ b/pina/_src/equation/base_equation.py
@@ -0,0 +1,67 @@
+"""Module for the Base Equation."""
+
+from abc import ABCMeta, abstractmethod
+import torch
+
+
+class BaseEquation(metaclass=ABCMeta):
+ """
+ Base class for all equations, implementing common functionality.
+
+ Equations are fundamental components in PINA, representing mathematical
+ constraints that must be satisfied by the model outputs. They can be passed
+ to :class:`~pina.condition.condition.Condition` objects to define the
+ conditions under which the model is trained.
+
+ All specific equation types should inherit from this class and implement its
+ abstract methods.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ @abstractmethod
+ def residual(self, input_, output_, params_):
+ """
+ Evaluate the equation residual at the given inputs.
+
+ :param LabelTensor input_: The input points where the residual is
+ computed.
+ :param LabelTensor output_: The output tensor, potentially produced by a
+ :class:`torch.nn.Module` instance.
+ :param dict params_: An optional dictionary of unknown parameters, used
+ in :class:`~pina.problem.inverse_problem.InverseProblem` settings.
+ If the equation is not related to an inverse problem, this should be
+ set to ``None``. Default is ``None``.
+ :return: The residual values of the equation.
+ :rtype: LabelTensor
+ """
+
+ def to(self, device):
+ """
+ Move all tensor attributes to the specified device.
+
+ :param torch.device device: The target device to move the tensors to.
+ :return: The instance moved to the specified device.
+ :rtype: BaseEquation
+ """
+ # Iterate over all attributes of the Equation
+ for key, val in self.__dict__.items():
+
+ # Move tensors in dictionaries to the specified device
+ if isinstance(val, dict):
+ self.__dict__[key] = {
+ k: v.to(device) if torch.is_tensor(v) else v
+ for k, v in val.items()
+ }
+
+ # Move tensors in lists to the specified device
+ elif isinstance(val, list):
+ self.__dict__[key] = [
+ v.to(device) if torch.is_tensor(v) else v for v in val
+ ]
+
+ # Move tensor attributes to the specified device
+ elif torch.is_tensor(val):
+ self.__dict__[key] = val.to(device)
+
+ return self
diff --git a/pina/_src/equation/equation.py b/pina/_src/equation/equation.py
new file mode 100644
index 000000000..d10da2bbe
--- /dev/null
+++ b/pina/_src/equation/equation.py
@@ -0,0 +1,65 @@
+"""Module for the Equation."""
+
+import inspect
+from pina._src.equation.base_equation import BaseEquation
+
+
+class Equation(BaseEquation):
+ """
+ Implementation of the Equation class, representing a single mathematical
+ equation to be satisfied by the model outputs.
+
+ It can be passed to a :class:`~pina.condition.condition.Condition` object to
+ define the conditions under which the model is trained.
+ """
+
+ def __init__(self, equation):
+ """
+ Initialization of the :class:`Equation` class.
+
+ :param Callable equation: A callable function used to compute the
+ residual of a mathematical equation.
+ :raises ValueError: If the equation is not a callable function.
+ """
+ # Check consistency
+ if not callable(equation):
+ raise ValueError(f"Expected a callable function, got {equation}")
+
+ # Compute the signature length
+ sig = inspect.signature(equation)
+ self.__len_sig = len(sig.parameters)
+ self.__equation = equation
+
+ def residual(self, input_, output_, params_=None):
+ """
+ Evaluate the equation residual at the given inputs.
+
+ :param LabelTensor input_: The input points where the residual is
+ computed.
+ :param LabelTensor output_: The output tensor, potentially produced by a
+ :class:`torch.nn.Module` instance.
+ :param dict params_: An optional dictionary of unknown parameters, used
+ in :class:`~pina.problem.inverse_problem.InverseProblem` settings.
+ If the equation is not related to an inverse problem, this should be
+ set to ``None``. Default is ``None``.
+ :raises RuntimeError: If the underlying equation signature is neither of
+ length 2 for direct problems nor of length 3 for inverse problems.
+ :return: The residual values of the equation.
+ :rtype: LabelTensor
+ """
+ # Move the equation to the input_ device
+ self.to(input_.device)
+
+ # Evaluate the equation for direct problems
+ if self.__len_sig == 2:
+ return self.__equation(input_, output_)
+
+ # Evaluate the equation for inverse problems
+ if self.__len_sig == 3:
+ return self.__equation(input_, output_, params_)
+
+ # Raise an error if the signature length is unexpected
+ raise RuntimeError(
+ f"Unexpected number of arguments in equation: {self.__len_sig}. "
+ "Expected either 2 for direct problems, or 3 for inverse problems."
+ )
diff --git a/pina/_src/equation/equation_interface.py b/pina/_src/equation/equation_interface.py
new file mode 100644
index 000000000..fa59de678
--- /dev/null
+++ b/pina/_src/equation/equation_interface.py
@@ -0,0 +1,36 @@
+"""Module for the Equation Interface."""
+
+from abc import ABCMeta, abstractmethod
+
+
+class EquationInterface(metaclass=ABCMeta):
+ """
+ Abstract interface for all equations.
+ """
+
+ @abstractmethod
+ def residual(self, input_, output_, params_=None):
+ """
+ Evaluate the equation residual at the given inputs.
+
+ :param LabelTensor input_: The input points where the residual is
+ computed.
+ :param LabelTensor output_: The output tensor, potentially produced by a
+ :class:`torch.nn.Module` instance.
+ :param dict params_: An optional dictionary of unknown parameters, used
+ in :class:`~pina.problem.inverse_problem.InverseProblem` settings.
+ If the equation is not related to an inverse problem, this should be
+ set to ``None``. Default is ``None``.
+ :return: The residual values of the equation.
+ :rtype: LabelTensor
+ """
+
+ @abstractmethod
+ def to(self, device):
+ """
+ Move all tensor attributes to the specified device.
+
+ :param torch.device device: The target device to move the tensors to.
+ :return: The instance moved to the specified device.
+ :rtype: EquationInterface
+ """
diff --git a/pina/_src/equation/system_equation.py b/pina/_src/equation/system_equation.py
new file mode 100644
index 000000000..7d3bdafd4
--- /dev/null
+++ b/pina/_src/equation/system_equation.py
@@ -0,0 +1,118 @@
+"""Module for the System of Equation."""
+
+from typing import Callable
+import torch
+from pina._src.equation.base_equation import BaseEquation
+from pina._src.core.utils import check_consistency
+from pina._src.equation.equation import Equation
+
+
+class SystemEquation(BaseEquation):
+ """
+ Implementation of the SystemEquation class, representing a system of
+ mathematical equation to be satisfied by the model outputs. It is useful for
+ multi-component outputs or coupled problems, where multiple constraints must
+ be evaluated together.
+
+ It can be passed to a :class:`~pina.condition.condition.Condition` object to
+ define the conditions under which the model is trained.
+
+ Each equation in the system must be either an instance of
+ :class:`~pina.equation.equation.Equation`, or a callable function.
+
+ Residuals are computed independently for each equation and then aggregated
+ using an optional reduction (e.g., ``mean``, ``sum``). The final result is
+ returned as a single :class:`~pina.LabelTensor`.
+
+ :Example:
+
+ >>> pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
+ >>> pts.requires_grad = True
+ >>> output_ = torch.pow(pts, 2)
+ >>> output_.labels = ["u", "v"]
+ >>> system_equation = SystemEquation(
+ ... [
+ ... FixedValue(value=1.0, components=["u"]),
+ ... FixedGradient(value=0.0, components=["v"], d=["y"]),
+ ... ],
+ ... reduction="mean",
+ ... )
+ >>> residual = system_equation.residual(pts, output_)
+ """
+
+ def __init__(self, list_equation, reduction=None):
+ """
+ Initialization of the :class:`SystemEquation` class.
+
+ :param list_equation: The list of equations used for the computation of
+ the residuals. Each element of the list can be either a callable
+ function or a :class:`~pina.equation.equation.Equation` instance.
+ :type list_equation: list[Callable] | list[Equation]
+ :param reduction: The method used to combine the residuals from each
+ equation. Available options are: ``None``, ``"mean"``, ``"sum"``, or
+ a custom callable. If ``None``, no reduction is applied. If
+ ``"mean"``, the residuals are averaged. If ``"sum"``, the residuals
+ are summed. If a callable is provided, it is used as a custom
+ reduction (no validation is performed).
+ :raises ValueError: If the list of equations is not a list.
+ :raises ValueError: If any element of the list of equations is not a
+ callable function or a :class:`~pina.equation.equation.Equation`
+ instance.
+ :raises ValueError: If an invalid reduction method is used.
+ """
+ # Check consistency
+ check_consistency([list_equation], list)
+ check_consistency(list_equation, (Callable, Equation))
+
+ # Convert all callable functions to Equation instances, if necessary
+ self.equations = [
+ equation if isinstance(equation, Equation) else Equation(equation)
+ for equation in list_equation
+ ]
+
+ # Validate and set the reduction method
+ if reduction == "mean":
+ self.reduction = torch.mean
+ elif reduction == "sum":
+ self.reduction = torch.sum
+ elif (reduction is None) or callable(reduction):
+ self.reduction = reduction
+ else:
+ raise ValueError(
+ "Invalid reduction method. Available options include: None, "
+ "'mean', 'sum', or a custom callable."
+ )
+
+ def residual(self, input_, output_, params_=None):
+ """
+ Evaluate each equation residual from the system of equations at the
+ given inputs and aggregate it according to the specified ``reduction``.
+
+ :param LabelTensor input_: The input points where the residual is
+ computed.
+ :param LabelTensor output_: The output tensor, potentially produced by a
+ :class:`torch.nn.Module` instance.
+ :param dict params_: An optional dictionary of unknown parameters, used
+ in :class:`~pina.problem.inverse_problem.InverseProblem` settings.
+ If the equation is not related to an inverse problem, this should be
+ set to ``None``. Default is ``None``.
+ :return: The aggregated residuals of the system of equations.
+ :rtype: LabelTensor
+ """
+ # Move the equation to the input_ device
+ self.to(input_.device)
+
+ # Compute the residual for each equation
+ residual = torch.cat(
+ [
+ equation.residual(input_, output_, params_)
+ for equation in self.equations
+ ],
+ dim=-1,
+ )
+
+ # Skip reduction if not specified
+ if self.reduction is None:
+ return residual
+
+ return self.reduction(residual, dim=-1)
diff --git a/pina/_src/equation/zoo/__init__.py b/pina/_src/equation/zoo/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/equation/zoo/acoustic_wave_equation.py b/pina/_src/equation/zoo/acoustic_wave_equation.py
new file mode 100644
index 000000000..8a6d2bf07
--- /dev/null
+++ b/pina/_src/equation/zoo/acoustic_wave_equation.py
@@ -0,0 +1,62 @@
+"""Module for defining the acoustic wave equation."""
+
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import laplacian
+from pina._src.core.utils import check_consistency
+
+
+class AcousticWaveEquation(Equation):
+ r"""
+ Implementation of the N-dimensional isotropic acoustic wave equation.
+ The equation is defined as follows:
+
+ .. math::
+
+ \frac{\partial^2 u}{\partial t^2} - c^2 \Delta u = 0
+
+ or alternatively:
+
+ .. math::
+
+ \Box u = 0
+
+ Here, :math:`c` is the wave propagation speed, and :math:`\Box` is the
+ d'Alembert operator.
+ """
+
+ def __init__(self, c):
+ """
+ Initialization of the :class:`AcousticWaveEquation` class.
+
+ :param c: The wave propagation speed.
+ :type c: float | int
+ """
+ check_consistency(c, (float, int))
+ self.c = c
+
+ def equation(input_, output_):
+ """
+ Implementation of the acoustic wave equation.
+
+ :param LabelTensor input_: The input data of the problem.
+ :param LabelTensor output_: The output data of the problem.
+ :return: The residual of the acoustic wave equation.
+ :rtype: LabelTensor
+ :raises ValueError: If the ``input_`` labels do not contain the time
+ variable 't'.
+ """
+ # Ensure time is passed as input
+ if "t" not in input_.labels:
+ raise ValueError(
+ "The ``input_`` labels must contain the time 't' variable."
+ )
+
+ # Compute the time second derivative and the spatial laplacian
+ u_tt = laplacian(output_, input_, d=["t"])
+ u_xx = laplacian(
+ output_, input_, d=[di for di in input_.labels if di != "t"]
+ )
+
+ return u_tt - self.c**2 * u_xx
+
+ super().__init__(equation)
diff --git a/pina/_src/equation/zoo/advection_equation.py b/pina/_src/equation/zoo/advection_equation.py
new file mode 100644
index 000000000..81e476bd5
--- /dev/null
+++ b/pina/_src/equation/zoo/advection_equation.py
@@ -0,0 +1,93 @@
+"""Module for defining the advection equation."""
+
+import torch
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import grad
+from pina._src.core.utils import check_consistency
+
+
+class AdvectionEquation(Equation):
+ r"""
+ Implementation of the N-dimensional advection equation with constant
+ velocity parameter. The equation is defined as follows:
+
+ .. math::
+
+ \frac{\partial u}{\partial t} + c \cdot \nabla u = 0
+
+ Here, :math:`c` is the advection velocity parameter.
+ """
+
+ def __init__(self, c):
+ """
+ Initialization of the :class:`AdvectionEquation` class.
+
+ :param c: The advection velocity. If a scalar is provided, the same
+ velocity is applied to all spatial dimensions. If a list is
+ provided, it must contain one value per spatial dimension.
+ :type c: float | int | List[float] | List[int]
+ :raises ValueError: If ``c`` is an empty list.
+ """
+ # Check consistency
+ check_consistency(c, (float, int))
+ if isinstance(c, list):
+ if len(c) < 1:
+ raise ValueError("'c' cannot be an empty list.")
+ else:
+ c = [c]
+
+ # Store advection velocity parameter
+ self.c = torch.tensor(c).unsqueeze(0)
+
+ def equation(input_, output_):
+ """
+ Implementation of the advection equation.
+
+ :param LabelTensor input_: The input data of the problem.
+ :param LabelTensor output_: The output data of the problem.
+ :return: The residual of the advection equation.
+ :rtype: LabelTensor
+ :raises ValueError: If the ``input_`` labels do not contain the time
+ variable 't'.
+ :raises ValueError: If ``c`` is a list and its length is not
+ consistent with the number of spatial dimensions.
+ """
+ # Store labels
+ input_lbl = input_.labels
+ spatial_d = [di for di in input_lbl if di != "t"]
+
+ # Ensure time is passed as input
+ if "t" not in input_lbl:
+ raise ValueError(
+ "The ``input_`` labels must contain the time 't' variable."
+ )
+
+ # Ensure consistency of c length
+ if self.c.shape[-1] != len(input_lbl) - 1 and self.c.shape[-1] > 1:
+ raise ValueError(
+ "If 'c' is passed as a list, its length must be equal to "
+ "the number of spatial dimensions."
+ )
+
+ # Repeat c to ensure consistent shape for advection
+ c = self.c.repeat(output_.shape[0], 1)
+ if c.shape[1] != (len(input_lbl) - 1):
+ c = c.repeat(1, len(input_lbl) - 1)
+
+ # Add a dimension to c for the following operations
+ c = c.unsqueeze(-1)
+
+ # Compute the time derivative and the spatial gradient
+ time_der = grad(output_, input_, components=None, d="t")
+ grads = grad(output_=output_, input_=input_, d=spatial_d)
+
+ # Reshape and transpose
+ tmp = grads.reshape(*output_.shape, len(spatial_d))
+ tmp = tmp.transpose(-1, -2)
+
+ # Compute advection term
+ adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
+
+ return time_der + adv
+
+ super().__init__(equation)
diff --git a/pina/_src/equation/zoo/allen_cahn_equation.py b/pina/_src/equation/zoo/allen_cahn_equation.py
new file mode 100644
index 000000000..e7091add2
--- /dev/null
+++ b/pina/_src/equation/zoo/allen_cahn_equation.py
@@ -0,0 +1,58 @@
+"""Module for defining the Allen-Cahn equation."""
+
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import grad, laplacian
+from pina._src.core.utils import check_consistency
+
+
+class AllenCahnEquation(Equation):
+ r"""
+ Implementation of the N-dimensional Allen-Cahn equation, defined as follows:
+
+ .. math::
+
+ \frac{\partial u}{\partial t} - \alpha \Delta u + \beta(u^3 - u) = 0
+
+ Here, :math:`\alpha` and :math:`\beta` are parameters of the equation.
+ """
+
+ def __init__(self, alpha, beta):
+ """
+ Initialization of the :class:`AllenCahnEquation` class.
+
+ :param alpha: The diffusion coefficient.
+ :type alpha: float | int
+ :param beta: The reaction coefficient.
+ :type beta: float | int
+ """
+ check_consistency(alpha, (float, int))
+ check_consistency(beta, (float, int))
+ self.alpha = alpha
+ self.beta = beta
+
+ def equation(input_, output_):
+ """
+ Implementation of the Allen-Cahn equation.
+
+ :param LabelTensor input_: The input data of the problem.
+ :param LabelTensor output_: The output data of the problem.
+ :return: The residual of the Allen-Cahn equation.
+ :rtype: LabelTensor
+ :raises ValueError: If the ``input_`` labels do not contain the time
+ variable 't'.
+ """
+ # Ensure time is passed as input
+ if "t" not in input_.labels:
+ raise ValueError(
+ "The ``input_`` labels must contain the time 't' variable."
+ )
+
+ # Compute the time derivative and the spatial laplacian
+ u_t = grad(output_, input_, d=["t"])
+ u_xx = laplacian(
+ output_, input_, d=[di for di in input_.labels if di != "t"]
+ )
+
+ return u_t - self.alpha * u_xx + self.beta * (output_**3 - output_)
+
+ super().__init__(equation)
diff --git a/pina/_src/equation/zoo/burgers_equation.py b/pina/_src/equation/zoo/burgers_equation.py
new file mode 100644
index 000000000..07c8eed22
--- /dev/null
+++ b/pina/_src/equation/zoo/burgers_equation.py
@@ -0,0 +1,84 @@
+"""Module for defining the Burgers equation."""
+
+from pina._src.core.operator import laplacian, grad
+from pina._src.core.utils import check_consistency
+from pina._src.equation.equation import Equation
+import torch
+
+
+class BurgersEquation(Equation):
+ r"""
+ Implementation of the N-dimensional Burgers' equation, defined as follows:
+
+ .. math::
+
+ \frac{\partial u}{\partial t} + u \cdot \nabla u = \nu \Delta u
+
+ Here, :math:`\nu` is the viscosity coefficient.
+ """
+
+ def __init__(self, nu):
+ """
+ Initialization of the :class:`BurgersEquation` class.
+
+ :param nu: The viscosity coefficient.
+ :type nu: float | int
+ :raises ValueError: If ``nu`` is not a float or an int.
+ :raises ValueError: If ``nu`` is negative.
+ """
+ # Check consistency
+ check_consistency(nu, (float, int))
+ if nu < 0:
+ raise ValueError(
+ "The viscosity ``nu`` must be a non-negative float or int."
+ )
+
+ # Store viscosity coefficient
+ self.nu = nu
+
+ def equation(input_, output_):
+ """
+ Implementation of the Burgers' equation.
+
+ :param LabelTensor input_: The input data of the problem.
+ :param LabelTensor output_: The output data of the problem.
+ :raises ValueError: If the number of output components does not
+ match the number of spatial dimensions.
+ :raises ValueError: If the ``input_`` labels do not contain the time
+ variable 't'.
+ :return: The residual of the Burgers' equation.
+ :rtype: LabelTensor
+ """
+ # Store labels
+ spatial_d = [di for di in input_.labels if di != "t"]
+
+ # Ensure consistency between output and spatial dimensions
+ if len(output_.labels) != len(spatial_d):
+ raise ValueError(
+ f"The number of output components must match the number of "
+ f"spatial dimensions. Got {len(output_.labels)} and "
+ f"{len(spatial_d)}."
+ )
+
+ # Ensure time is passed as input
+ if "t" not in input_.labels:
+ raise ValueError(
+ "The ``input_`` labels must contain the time 't' variable."
+ )
+
+ # Compute the differential terms
+ u_t = grad(output_, input_, d=["t"])
+ u_x = grad(output_, input_, d=spatial_d)
+ u_xx = laplacian(output_, input_, d=spatial_d)
+
+ # Compute the convective term componentwise
+ convection = torch.zeros_like(output_)
+ for i, c in enumerate(output_.labels):
+ convection[:, i] = sum(
+ output_[output_.labels[j]] * u_x[f"d{c}d{spatial_d[j]}"]
+ for j in range(len(spatial_d))
+ ).reshape(-1)
+
+ return u_t + convection - self.nu * u_xx
+
+ super().__init__(equation)
diff --git a/pina/_src/equation/zoo/diffusion_reaction_equation.py b/pina/_src/equation/zoo/diffusion_reaction_equation.py
new file mode 100644
index 000000000..4f276dd54
--- /dev/null
+++ b/pina/_src/equation/zoo/diffusion_reaction_equation.py
@@ -0,0 +1,61 @@
+"""Module for defining the Diffusion-Reaction equation."""
+
+from typing import Callable
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import grad, laplacian
+from pina._src.core.utils import check_consistency
+
+
+class DiffusionReactionEquation(Equation):
+ r"""
+ Implementation of the N-dimensional Diffusion-Reaction equation,
+ defined as follows:
+
+ .. math::
+
+ \frac{\partial u}{\partial t} - \alpha \Delta u - f = 0
+
+ Here, :math:`\alpha` is a parameter of the equation, while :math:`f` is the
+ reaction term.
+ """
+
+ def __init__(self, alpha, forcing_term):
+ """
+ Initialization of the :class:`DiffusionReactionEquation` class.
+
+ :param alpha: The diffusion coefficient.
+ :type alpha: float | int
+ :param Callable forcing_term: The forcing field function, taking as
+ input the points on which evaluation is required.
+ """
+ check_consistency(alpha, (float, int))
+ check_consistency(forcing_term, (Callable))
+ self.alpha = alpha
+ self.forcing_term = forcing_term
+
+ def equation(input_, output_):
+ """
+ Implementation of the Diffusion-Reaction equation.
+
+ :param LabelTensor input_: The input data of the problem.
+ :param LabelTensor output_: The output data of the problem.
+ :return: The residual of the Diffusion-Reaction equation.
+ :rtype: LabelTensor
+ :raises ValueError: If the ``input_`` labels do not contain the time
+ variable 't'.
+ """
+ # Ensure time is passed as input
+ if "t" not in input_.labels:
+ raise ValueError(
+ "The ``input_`` labels must contain the time 't' variable."
+ )
+
+ # Compute the time derivative and the spatial laplacian
+ u_t = grad(output_, input_, d=["t"])
+ u_xx = laplacian(
+ output_, input_, d=[di for di in input_.labels if di != "t"]
+ )
+
+ return u_t - self.alpha * u_xx - self.forcing_term(input_)
+
+ super().__init__(equation)
diff --git a/pina/_src/equation/zoo/fixed_flux.py b/pina/_src/equation/zoo/fixed_flux.py
new file mode 100644
index 000000000..858f3bdd1
--- /dev/null
+++ b/pina/_src/equation/zoo/fixed_flux.py
@@ -0,0 +1,54 @@
+"""Module for defining the fixed flux equation."""
+
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import div
+from pina._src.core.utils import check_consistency
+
+
+class FixedFlux(Equation):
+ """
+ Equation to enforce a fixed flux, or divergence, for a specific condition.
+ """
+
+ def __init__(self, value, components=None, d=None):
+ """
+ Initialization of the :class:`FixedFlux` class.
+
+ :param value: The fixed value to be enforced to the flux.
+ :type value: float | int
+ :param components: The name of the output variables for which the fixed
+ flux condition is applied. It should be a subset of the output
+ labels. If ``None``, all output variables are considered. Default is
+ ``None``.
+ :type components: str | list[str]
+ :param d: The name of the input variables on which the flux is computed.
+ It should be a subset of the input labels. If ``None``, all the
+ input variables are considered. Default is ``None``.
+ :type d: str | list[str]
+ :raises ValueError: If ``value`` is neither a float nor an integer.
+ :raises ValueError: If, when provided, ``components`` is neither a
+ string nor a list of strings.
+ :raises ValueError: If, when provided, ``d`` is neither a string nor a
+ list of strings.
+ """
+ # Check consistency
+ check_consistency(value, (float, int))
+ if components is not None:
+ check_consistency(components, str)
+ if d is not None:
+ check_consistency(d, str)
+
+ def equation(input_, output_):
+ """
+ Definition of the equation to enforce a fixed flux.
+
+ :param LabelTensor input_: The input points where the residual is
+ computed.
+ :param LabelTensor output_: The output tensor, potentially produced
+ by a :class:`torch.nn.Module` instance.
+ :return: The residual values of the equation.
+ :rtype: LabelTensor
+ """
+ return div(output_, input_, components=components, d=d) - value
+
+ super().__init__(equation)
diff --git a/pina/_src/equation/zoo/fixed_gradient.py b/pina/_src/equation/zoo/fixed_gradient.py
new file mode 100644
index 000000000..2c60c007f
--- /dev/null
+++ b/pina/_src/equation/zoo/fixed_gradient.py
@@ -0,0 +1,53 @@
+"""Module for defining the fixed gradient equation."""
+
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import grad
+from pina._src.core.utils import check_consistency
+
+
+class FixedGradient(Equation):
+ """
+ Equation to enforce a fixed gradient for a specific condition.
+ """
+
+ def __init__(self, value, components=None, d=None):
+ """
+ Initialization of the :class:`FixedGradient` class.
+
+ :param float value: The fixed value to be enforced to the gradient.
+ :param components: The name of the output variables for which the fixed
+ gradient condition is applied. It should be a subset of the output
+ labels. If ``None``, all output variables are considered. Default is
+ ``None``.
+ :type components: str | list[str]
+ :param d: The name of the input variables on which the gradient is
+ computed. It should be a subset of the input labels. If ``None``,
+ all the input variables are considered. Default is ``None``.
+ :type d: str | list[str]
+ :raises ValueError: If ``value`` is neither a float nor an integer.
+ :raises ValueError: If, when provided, ``components`` is neither a
+ string nor a list of strings.
+ :raises ValueError: If, when provided, ``d`` is neither a string nor a
+ list of strings.
+ """
+ # Check consistency
+ check_consistency(value, (float, int))
+ if components is not None:
+ check_consistency(components, str)
+ if d is not None:
+ check_consistency(d, str)
+
+ def equation(input_, output_):
+ """
+ Definition of the equation to enforce a fixed gradient.
+
+ :param LabelTensor input_: The input points where the residual is
+ computed.
+ :param LabelTensor output_: The output tensor, potentially produced
+ by a :class:`torch.nn.Module` instance.
+ :return: The residual values of the equation.
+ :rtype: LabelTensor
+ """
+ return grad(output_, input_, components=components, d=d) - value
+
+ super().__init__(equation)
diff --git a/pina/_src/equation/zoo/fixed_laplacian.py b/pina/_src/equation/zoo/fixed_laplacian.py
new file mode 100644
index 000000000..8d0fa7cf4
--- /dev/null
+++ b/pina/_src/equation/zoo/fixed_laplacian.py
@@ -0,0 +1,68 @@
+"""Module for defining the fixed laplacian equation."""
+
+import warnings
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import laplacian
+from pina._src.core.utils import check_consistency
+
+
+class FixedLaplacian(Equation):
+ """
+ Equation to enforce a fixed laplacian for a specific condition.
+ """
+
+ def __init__(self, value, components=None, d=None):
+ """
+ Initialization of the :class:`FixedLaplacian` class.
+
+ :param value: The fixed value to be enforced to the laplacian.
+ :type value: float | int
+ :param components: The name of the output variables for which the fixed
+ laplace condition is applied. It should be a subset of the output
+ labels. If ``None``, all output variables are considered. Default is
+ ``None``.
+ :type components: str | list[str]
+ :param d: The name of the input variables on which the laplacian is
+ computed. It should be a subset of the input labels. If ``None``,
+ all the input variables are considered. Default is ``None``.
+ :type d: str | list[str]
+ :raises ValueError: If ``value`` is neither a float nor an integer.
+ :raises ValueError: If, when provided, ``components`` is neither a
+ string nor a list of strings.
+ :raises ValueError: If, when provided, ``d`` is neither a string nor a
+ list of strings.
+ """
+ # Check consistency
+ check_consistency(value, (float, int))
+ if components is not None:
+ check_consistency(components, str)
+ if d is not None:
+ check_consistency(d, str)
+
+ def equation(input_, output_):
+ """
+ Definition of the equation to enforce a fixed laplacian.
+
+ :param LabelTensor input_: The input points where the residual is
+ computed.
+ :param LabelTensor output_: The output tensor, potentially produced
+ by a :class:`torch.nn.Module` instance.
+ :return: The residual values of the equation.
+ :rtype: LabelTensor
+ """
+ return (
+ laplacian(output_, input_, components=components, d=d) - value
+ )
+
+ super().__init__(equation)
+
+
+# Back-compatibility with version 0.2, to be removed soon
+class Laplace(FixedLaplacian):
+ def __init__(self, components=None, d=None):
+ warnings.warn(
+ "Laplace is deprecated, use FixedLaplacian with value=0.0 instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__(0.0, components=components, d=d)
diff --git a/pina/_src/equation/zoo/fixed_value.py b/pina/_src/equation/zoo/fixed_value.py
new file mode 100644
index 000000000..25c81c8b8
--- /dev/null
+++ b/pina/_src/equation/zoo/fixed_value.py
@@ -0,0 +1,48 @@
+"""Module for defining the fixed value equation."""
+
+from pina._src.equation.equation import Equation
+from pina._src.core.utils import check_consistency
+
+
+class FixedValue(Equation):
+ """
+ Equation to enforce a fixed value. Can be used to enforce Dirichlet Boundary
+ conditions.
+ """
+
+ def __init__(self, value, components=None):
+ """
+ Initialization of the :class:`FixedValue` class.
+
+ :param value: The fixed value to be enforced.
+ :type value: float | int
+ :param components: The name of the output variables for which the fixed
+ value condition is applied. It should be a subset of the output
+ labels. If ``None``, all output variables are considered. Default is
+ ``None``.
+ :type components: str | list[str]
+ :raises ValueError: If ``value`` is neither a float nor an integer.
+ :raises ValueError: If, when provided, ``components`` is neither a
+ string nor a list of strings.
+ """
+ # Check consistency
+ check_consistency(value, (float, int))
+ if components is not None:
+ check_consistency(components, str)
+
+ def equation(_, output_):
+ """
+ Definition of the equation to enforce a fixed value.
+
+ :param LabelTensor input_: The input points where the residual is
+ computed.
+ :param LabelTensor output_: The output tensor, potentially produced
+ by a :class:`torch.nn.Module` instance.
+ :return: The residual values of the equation.
+ :rtype: LabelTensor
+ """
+ if components is None:
+ return output_ - value
+ return output_.extract(components) - value
+
+ super().__init__(equation)
diff --git a/pina/_src/equation/zoo/helmholtz_equation.py b/pina/_src/equation/zoo/helmholtz_equation.py
new file mode 100644
index 000000000..57b353bf0
--- /dev/null
+++ b/pina/_src/equation/zoo/helmholtz_equation.py
@@ -0,0 +1,47 @@
+"""Module for defining the Helmholtz equation."""
+
+from typing import Callable
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import laplacian
+from pina._src.core.utils import check_consistency
+
+
+class HelmholtzEquation(Equation):
+ r"""
+ Implementation of the Helmholtz equation, defined as follows:
+
+ .. math::
+
+ \Delta u + k u - f = 0
+
+ Here, :math:`k` is the squared wavenumber, while :math:`f` is the
+ forcing term.
+ """
+
+ def __init__(self, k, forcing_term):
+ """
+ Initialization of the :class:`HelmholtzEquation` class.
+
+ :param k: The squared wavenumber.
+ :type k: float | int
+ :param Callable forcing_term: The forcing field function, taking as
+ input the points on which evaluation is required.
+ """
+ check_consistency(k, (int, float))
+ check_consistency(forcing_term, (Callable))
+ self.k = k
+ self.forcing_term = forcing_term
+
+ def equation(input_, output_):
+ """
+ Implementation of the Helmholtz equation.
+
+ :param LabelTensor input_: The input data of the problem.
+ :param LabelTensor output_: The output data of the problem.
+ :return: The residual of the Helmholtz equation.
+ :rtype: LabelTensor
+ """
+ lap = laplacian(output_, input_)
+ return lap + self.k * output_ - self.forcing_term(input_)
+
+ super().__init__(equation)
diff --git a/pina/_src/equation/zoo/poisson_equation.py b/pina/_src/equation/zoo/poisson_equation.py
new file mode 100644
index 000000000..2ab80ff33
--- /dev/null
+++ b/pina/_src/equation/zoo/poisson_equation.py
@@ -0,0 +1,42 @@
+"""Module for defining the Poisson equation."""
+
+from typing import Callable
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import laplacian
+from pina._src.core.utils import check_consistency
+
+
+class PoissonEquation(Equation):
+ r"""
+ Implementation of the Poisson equation, defined as follows:
+
+ .. math::
+
+ \Delta u - f = 0
+
+ Here, :math:`f` is the forcing term.
+ """
+
+ def __init__(self, forcing_term):
+ """
+ Initialization of the :class:`PoissonEquation` class.
+
+ :param Callable forcing_term: The forcing field function, taking as
+ input the points on which evaluation is required.
+ """
+ check_consistency(forcing_term, (Callable))
+ self.forcing_term = forcing_term
+
+ def equation(input_, output_):
+ """
+ Implementation of the Poisson equation.
+
+ :param LabelTensor input_: The input data of the problem.
+ :param LabelTensor output_: The output data of the problem.
+ :return: The residual of the Poisson equation.
+ :rtype: LabelTensor
+ """
+ lap = laplacian(output_, input_)
+ return lap - self.forcing_term(input_)
+
+ super().__init__(equation)
diff --git a/pina/_src/loss/__init__.py b/pina/_src/loss/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/loss/base_dual_loss.py b/pina/_src/loss/base_dual_loss.py
new file mode 100644
index 000000000..9287142bc
--- /dev/null
+++ b/pina/_src/loss/base_dual_loss.py
@@ -0,0 +1,54 @@
+"""Module for the BaseDualLoss class."""
+
+import torch
+from pina._src.loss.dual_loss_interface import DualLossInterface
+
+
+class BaseDualLoss(DualLossInterface):
+ """
+ Base class for all losses requiring both an input and a target tensor,
+ implementing common functionality.
+
+ All specific loss types should inherit from this class and implement its
+ abstract methods.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ # Define available reduction methods
+ _REDUCTION_METHOD = {
+ "sum": lambda x: torch.sum(x, keepdim=True, dim=-1),
+ "mean": lambda x: torch.mean(x, keepdim=True, dim=-1),
+ "none": lambda x: x,
+ }
+
+ def __init__(self, reduction="mean"):
+ """
+ Initialization of the :class:`BaseDualLoss` class.
+
+ :param str reduction: The reduction method to aggregate pointwise loss
+ values. Available options include: ``"none"`` for unreduced loss,
+ ``"mean"`` for the average of the loss values, and ``"sum"`` for
+ their total sum. Default is ``"mean"``.
+ :raises ValueError: If the specified reduction method is not among the
+ available options.
+ """
+ # Check that the reduction method is available
+ if reduction not in self._REDUCTION_METHOD:
+ raise ValueError(
+ f"Invalid reduction method. Available options: "
+ f"{list(self._REDUCTION_METHOD.keys())}. Got {reduction}."
+ )
+
+ # Initialization
+ super().__init__(reduction=reduction, size_average=None, reduce=None)
+
+ def _reduction(self, loss):
+ """
+ Apply the configured reduction operation to pointwise loss values.
+
+ :param torch.Tensor loss: The tensor of pointwise losses.
+ :return: The reduced loss tensor.
+ :rtype: torch.Tensor
+ """
+ return self._REDUCTION_METHOD[self.reduction](loss)
diff --git a/pina/_src/loss/dual_loss_interface.py b/pina/_src/loss/dual_loss_interface.py
new file mode 100644
index 000000000..6db6bc44f
--- /dev/null
+++ b/pina/_src/loss/dual_loss_interface.py
@@ -0,0 +1,32 @@
+"""Module for the Loss Interface."""
+
+from abc import ABCMeta, abstractmethod
+from torch.nn.modules.loss import _Loss
+
+
+class DualLossInterface(_Loss, metaclass=ABCMeta):
+ """
+ Abstract interface for all losses requiring both an input and a target
+ tensor.
+ """
+
+ @abstractmethod
+ def forward(self, input, target):
+ """
+ Forward method of the loss function.
+
+ :param torch.Tensor input: The input tensor.
+ :param torch.Tensor target: The target tensor.
+ :return: The computed loss.
+ :rtype: torch.Tensor
+ """
+
+ @abstractmethod
+ def _reduction(self, loss):
+ """
+ Apply the configured reduction operation to pointwise loss values.
+
+ :param torch.Tensor loss: The tensor of pointwise losses.
+ :return: The reduced loss tensor.
+ :rtype: torch.Tensor
+ """
diff --git a/pina/_src/loss/lp_loss.py b/pina/_src/loss/lp_loss.py
new file mode 100644
index 000000000..c2d25ea4e
--- /dev/null
+++ b/pina/_src/loss/lp_loss.py
@@ -0,0 +1,91 @@
+"""Module for the Lp Loss class."""
+
+import torch
+from pina._src.loss.base_dual_loss import BaseDualLoss
+from pina._src.core.utils import check_consistency
+
+
+class LpLoss(BaseDualLoss):
+ r"""
+ Implementation of the :math:`L^p` loss measuring the pointwise :math:`L^p`
+ distance between an input tensor :math:`x` and a target tensor :math:`y`.
+
+ Given a batch of size :math:`N` and feature dimension :math:`D`, the
+ unreduced loss (``reduction="none"``) is defined as:
+
+ .. math::
+ L = \{l_1, \dots, l_N\}^\top, \quad
+ l_n = \left( \sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p \right)^{1/p}
+
+ If ``relative=True``, each term is normalized by the :math:`L^p` norm of the
+ input tensor :math:`x`:
+
+ .. math::
+ l_n = \frac{\left( \sum_{i=1}^{D} |x_n^i - y_n^i|^p \right)^{1/p}}
+ {\left( \sum_{i=1}^{D} |x_n^i|^p \right)^{1/p}}
+
+ If ``reduction`` is set to ``"mean"`` or ``"sum"``, the vector :math:`L`
+ is aggregated accordingly:
+
+ .. math::
+ \ell(x, y) =
+ \begin{cases}
+ \operatorname{mean}(L), & \text{if reduction} = \text{``mean''} \\
+ \operatorname{sum}(L), & \text{if reduction} = \text{``sum''}
+ \end{cases}
+
+ where :math:`N` is the batch size.
+ """
+
+ def __init__(self, p=2, reduction="mean", relative=False):
+ """
+ Initialization of the :class:`LpLoss` class.
+
+ :param p: The order of the norm. It can be a numeric value for standard
+ p-norms or one of the following strings: ``"inf"`` for maximum
+ absolute value, ``"-inf"`` for minimum absolute value. The values
+ ``"inf"`` and ``"-inf"`` are internally converted to their floating
+ counterparts. Default is ``2``.
+ :type p: int | float | str
+ :param str reduction: The reduction method to aggregate pointwise loss
+ values. Available options include: ``"none"`` for unreduced loss,
+ ``"mean"`` for the average of the loss values, and ``"sum"`` for
+ their total sum. Default is ``"mean"``.
+ :param bool relative: If ``True``, computes the relative error.
+ Default is ``False``.
+ :raises ValueError: If ``relative`` is not a boolean.
+ :raises ValueError: If ``p`` is not a valid norm order.
+ """
+ super().__init__(reduction=reduction)
+
+ # Convert to float if inf or -inf
+ if p == "inf":
+ p = float("inf")
+ elif p == "-inf":
+ p = float("-inf")
+
+ # Check consistency
+ check_consistency(relative, bool)
+ check_consistency(p, (int, float))
+
+ # Initialize attributes
+ self.p = p
+ self.relative = relative
+
+ def forward(self, input, target):
+ """
+ Forward method of the loss function.
+
+ :param torch.Tensor input: The input tensor.
+ :param torch.Tensor target: The target tensor.
+ :return: The computed loss.
+ :rtype: torch.Tensor
+ """
+ # Compute the standard loss
+ loss = torch.linalg.norm((input - target), ord=self.p, dim=-1)
+
+ # Compute the input norm for relative error
+ if self.relative:
+ loss = loss / torch.linalg.norm(input, ord=self.p, dim=-1)
+
+ return self._reduction(loss)
diff --git a/pina/_src/loss/power_loss.py b/pina/_src/loss/power_loss.py
new file mode 100644
index 000000000..b8a0821bb
--- /dev/null
+++ b/pina/_src/loss/power_loss.py
@@ -0,0 +1,81 @@
+"""Module for the Power Loss class."""
+
+import torch
+from pina._src.loss.base_dual_loss import BaseDualLoss
+from pina._src.core.utils import check_consistency, check_positive_integer
+
+
+class PowerLoss(BaseDualLoss):
+ r"""
+ Implementation of the Power loss, measuring the pointwise averaged
+ :math:`p`-power error between an input tensor :math:`x` and a target tensor
+ :math:`y`.
+
+ Given a batch of size :math:`N` and feature dimension :math:`D`, the
+ unreduced loss (``reduction="none"``) is defined as:
+
+ .. math::
+ L = \{l_1, \dots, l_N\}^\top, \quad
+ l_n = \frac{1}{D} \sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p
+
+ If ``relative=True``, each term is normalized by the averaged
+ :math:`p`-power magnitude of the input tensor :math:`x`:
+
+ .. math::
+ l_n = \frac{\frac{1}{D} \sum_{i=1}^{D} |x_n^i - y_n^i|^p}
+ {\frac{1}{D} \sum_{i=1}^{D} |x_n^i|^p}
+
+ If ``reduction`` is set to ``"mean"`` or ``"sum"``, the vector :math:`L`
+ is aggregated accordingly:
+
+ .. math::
+ \ell(x, y) =
+ \begin{cases}
+ \operatorname{mean}(L), & \text{if reduction} = \text{``mean''} \\
+ \operatorname{sum}(L), & \text{if reduction} = \text{``sum''}
+ \end{cases}
+
+ where :math:`N` is the batch size.
+ """
+
+ def __init__(self, p=2, reduction="mean", relative=False):
+ """
+ Initialization of the :class:`PowerLoss` class.
+
+ :param int p: The order of the p-norm. Default is ``2``.
+ :param str reduction: The reduction method to aggregate pointwise loss
+ values. Available options include: ``"none"`` for unreduced loss,
+ ``"mean"`` for the average of the loss values, and ``"sum"`` for
+ their total sum. Default is ``"mean"``.
+ :param bool relative: If ``True``, computes the relative error.
+ Default is ``False``.
+ :raises ValueError: If ``relative`` is not a boolean.
+ :raises ValueError: If ``p`` is not a positive integer.
+ """
+ super().__init__(reduction=reduction)
+
+ # Check consistency
+ check_consistency(relative, bool)
+ check_positive_integer(p, strict=True)
+
+ # Initialize attributes
+ self.p = p
+ self.relative = relative
+
+ def forward(self, input, target):
+ """
+ Forward method of the loss function.
+
+ :param torch.Tensor input: The input tensor.
+ :param torch.Tensor target: The target tensor.
+ :return: The computed loss.
+ :rtype: torch.Tensor
+ """
+ # Compute the standard loss
+ loss = torch.abs((input - target)).pow(self.p).mean(-1)
+
+ # Compute the input norm for relative error
+ if self.relative:
+ loss = loss / torch.abs(input).pow(self.p).mean(-1)
+
+ return self._reduction(loss)
diff --git a/pina/_src/loss/sinkhorn_loss.py b/pina/_src/loss/sinkhorn_loss.py
new file mode 100644
index 000000000..9feddc458
--- /dev/null
+++ b/pina/_src/loss/sinkhorn_loss.py
@@ -0,0 +1,138 @@
+"""Module for the SinkhornLoss class."""
+
+import torch
+from pina._src.loss.base_dual_loss import BaseDualLoss
+from pina._src.core.utils import check_consistency, check_positive_integer
+
+
+class SinkhornLoss(BaseDualLoss):
+ r"""
+ Implementation of the Sinkhorn loss measuring the entropy-regularized
+ optimal transport distance between two empirical distributions.
+
+ Given an input tensor :math:`x` with :math:`N` samples and a target tensor
+ :math:`y` with :math:`M` samples, both in :math:`\mathbb{R}^D`, the loss is
+ defined through the entropy-regularized optimal transport problem:
+
+ .. math::
+
+ W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)}
+ \langle C, \pi \rangle - \varepsilon H(\pi)
+
+ where :math:`\mu` and :math:`\nu` are the empirical distributions associated
+ with :math:`x` and :math:`y`, :math:`\pi` is a transport plan, and
+ :math:`\Pi(\mu, \nu)` is the set of admissible transport plans with
+ marginals :math:`\mu` and :math:`\nu`.
+
+ The cost matrix is defined as:
+
+ .. math::
+
+ C_{ij} = \left\| x_i - y_j \right\|_2^p
+
+ and the entropy term is:
+
+ .. math::
+
+ H(\pi) = - \sum_{i,j} \pi_{ij} \log \pi_{ij}
+
+ where :math:`\varepsilon > 0` controls the strength of the entropic
+ regularization.
+
+ The Sinkhorn iterations compute the optimal dual potentials :math:`f^\ast`
+ and :math:`g^\ast` in log space. The regularized optimal transport cost is
+ then recovered from the dual formulation as:
+
+ .. math::
+
+ W_\varepsilon = \langle a, f^\ast \rangle + \langle b, g^\ast \rangle
+
+ where :math:`a` and :math:`b` are uniform probability weights over the
+ :math:`N` input samples and :math:`M` target samples, respectively.
+
+ Unlike pointwise losses, the Sinkhorn loss compares whole empirical
+ distributions. Therefore, the output is always a scalar value.
+
+ Smaller values of ``eps`` provide a closer approximation to the true
+ Wasserstein distance, but may require more Sinkhorn iterations to converge.
+
+ .. seealso::
+
+ **Original reference:** Patrini, G., Carioni, M., Forr'e, P., Bhargav,
+ S., Welling, M., Van den Berg, R., Genewein, T., and Nielsen, F. (2019).
+ *Sinkhorn AutoEncoders*.
+ In Proceedings of the 35th Conference on Uncertainty in Artificial
+ Intelligence.
+ URL: ``_.
+ """
+
+ def __init__(self, p=2, eps=0.1, iterations=100):
+ """
+ Initialization of the :class:`SinkhornLoss` class.
+
+ :param int p: The exponent of the cost function. Default is ``2``.
+ :param eps: The entropy regularization strength. Smaller values provide
+ a closer approximation to the unregularized Wasserstein distance,
+ but may require more iterations for convergence. Default is ``0.1``.
+ :type eps: int | float
+ :param int iterations: The number of Sinkhorn iterations.
+ Default is ``100``.
+ :raises AssertionError: If ``iterations`` is not a positive integer.
+ :raises AssertionError: If ``p`` is not a positive integer.
+ :raises ValueError: If ``eps`` is not a positive numeric value.
+ """
+ # Initialize the base class with mean reduction
+ super().__init__(reduction="mean")
+
+ # Check consistency
+ check_positive_integer(iterations, strict=True)
+ check_positive_integer(p, strict=True)
+ check_consistency(eps, (int, float))
+ if eps <= 0:
+ raise ValueError(
+ f"Expected 'eps' to be strictly positive, but got {eps}."
+ )
+
+ # Initialize parameters
+ self.iterations = iterations
+ self.eps = eps
+ self.p = p
+
+ def forward(self, input, target):
+ """
+ Forward method of the loss function.
+
+ :param torch.Tensor input: The input tensor.
+ :param torch.Tensor target: The target tensor.
+ :return: The computed Sinkhorn loss value.
+ :rtype: torch.Tensor
+ """
+ # Extract the number of samples in input and target
+ n, m = input.shape[0], target.shape[0]
+
+ # Initialize log-uniform weights for the empirical distributions
+ log_a = -input.new_tensor(n).log().expand(n)
+ log_b = -target.new_tensor(m).log().expand(m)
+
+ # Initialize dual potentials f and g
+ f = torch.zeros(n, dtype=input.dtype, device=input.device)
+ g = torch.zeros(m, dtype=target.dtype, device=target.device)
+
+ # Define the cost matrix, shape (n, m)
+ C = torch.cdist(input, target, p=self.p) ** self.p
+
+ # Perform Sinkhorn iterations in log space for numerical stability
+ for _ in range(self.iterations):
+
+ # Update dual potential f with the softmin operation in log space
+ softmin_f = torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1)
+ f = self.eps * (log_a - softmin_f)
+
+ # Update dual potential g with the softmin operation in log space
+ softmin_g = torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0)
+ g = self.eps * (log_b - softmin_g)
+
+ # Compute the Sinkhorn loss as the sum of the means of f and g
+ loss = f.mean() + g.mean()
+
+ return self._reduction(loss.unsqueeze(0))
diff --git a/pina/_src/model/__init__.py b/pina/_src/model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/model/average_neural_operator.py b/pina/_src/model/average_neural_operator.py
similarity index 96%
rename from pina/model/average_neural_operator.py
rename to pina/_src/model/average_neural_operator.py
index 6019b96c6..e16e3430f 100644
--- a/pina/model/average_neural_operator.py
+++ b/pina/_src/model/average_neural_operator.py
@@ -2,9 +2,9 @@
import torch
from torch import nn
-from .block.average_neural_operator_block import AVNOBlock
-from .kernel_neural_operator import KernelNeuralOperator
-from ..utils import check_consistency
+from pina._src.model.block.average_neural_operator_block import AVNOBlock
+from pina._src.model.kernel_neural_operator import KernelNeuralOperator
+from pina._src.core.utils import check_consistency
class AveragingNeuralOperator(KernelNeuralOperator):
diff --git a/pina/_src/model/block/__init__.py b/pina/_src/model/block/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/model/block/average_neural_operator_block.py b/pina/_src/model/block/average_neural_operator_block.py
similarity index 97%
rename from pina/model/block/average_neural_operator_block.py
rename to pina/_src/model/block/average_neural_operator_block.py
index 91379abeb..4b5af8081 100644
--- a/pina/model/block/average_neural_operator_block.py
+++ b/pina/_src/model/block/average_neural_operator_block.py
@@ -2,7 +2,7 @@
import torch
from torch import nn
-from ...utils import check_consistency
+from pina._src.core.utils import check_consistency
class AVNOBlock(nn.Module):
diff --git a/pina/model/block/convolution.py b/pina/_src/model/block/convolution.py
similarity index 98%
rename from pina/model/block/convolution.py
rename to pina/_src/model/block/convolution.py
index 666f66a66..bfe7054af 100644
--- a/pina/model/block/convolution.py
+++ b/pina/_src/model/block/convolution.py
@@ -2,8 +2,8 @@
from abc import ABCMeta, abstractmethod
import torch
-from .stride import Stride
-from .utils_convolution import optimizing
+from pina._src.model.block.stride import Stride
+from pina._src.model.block.utils_convolution import optimizing
class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
diff --git a/pina/model/block/convolution_2d.py b/pina/_src/model/block/convolution_2d.py
similarity index 98%
rename from pina/model/block/convolution_2d.py
rename to pina/_src/model/block/convolution_2d.py
index 825ae613b..935bb0afa 100644
--- a/pina/model/block/convolution_2d.py
+++ b/pina/_src/model/block/convolution_2d.py
@@ -1,9 +1,9 @@
-"""Module for the Continuous Convolution class."""
+"""Module for the Continuous 2D Convolution class."""
import torch
-from .convolution import BaseContinuousConv
-from .utils_convolution import check_point, map_points_
-from .integral import Integral
+from pina._src.model.block.convolution import BaseContinuousConv
+from pina._src.model.block.utils_convolution import check_point, map_points_
+from pina._src.model.block.integral import Integral
class ContinuousConvBlock(BaseContinuousConv):
diff --git a/pina/model/block/embedding.py b/pina/_src/model/block/embedding.py
similarity index 99%
rename from pina/model/block/embedding.py
rename to pina/_src/model/block/embedding.py
index 1e44ec143..f9f05c119 100644
--- a/pina/model/block/embedding.py
+++ b/pina/_src/model/block/embedding.py
@@ -1,7 +1,7 @@
"""Modules for the the Embedding blocks."""
import torch
-from pina.utils import check_consistency
+from pina._src.core.utils import check_consistency
class PeriodicBoundaryEmbedding(torch.nn.Module):
diff --git a/pina/model/block/fourier_block.py b/pina/_src/model/block/fourier_block.py
similarity index 98%
rename from pina/model/block/fourier_block.py
rename to pina/_src/model/block/fourier_block.py
index 2983c840a..2510320ec 100644
--- a/pina/model/block/fourier_block.py
+++ b/pina/_src/model/block/fourier_block.py
@@ -2,9 +2,9 @@
import torch
from torch import nn
-from ...utils import check_consistency
+from pina._src.core.utils import check_consistency
-from .spectral import (
+from pina._src.model.block.spectral import (
SpectralConvBlock1D,
SpectralConvBlock2D,
SpectralConvBlock3D,
diff --git a/pina/model/block/gno_block.py b/pina/_src/model/block/gno_block.py
similarity index 100%
rename from pina/model/block/gno_block.py
rename to pina/_src/model/block/gno_block.py
diff --git a/pina/model/block/integral.py b/pina/_src/model/block/integral.py
similarity index 100%
rename from pina/model/block/integral.py
rename to pina/_src/model/block/integral.py
diff --git a/pina/_src/model/block/kan_block.py b/pina/_src/model/block/kan_block.py
new file mode 100644
index 000000000..77597d310
--- /dev/null
+++ b/pina/_src/model/block/kan_block.py
@@ -0,0 +1,158 @@
+"""Module for the Kolmogorov-Arnold Network block."""
+
+import torch
+from pina._src.model.vectorized_spline import VectorizedSpline
+from pina._src.core.utils import check_consistency, check_positive_integer
+
+
+class KANBlock(torch.nn.Module):
+ """
+ The inner block of the Kolmogorov-Arnold Network (KAN).
+
+ The block applies a spline transformation to the input, optionally combined
+ with a linear transformation of a base activation function. The output is
+ aggregated across input dimensions to produce the final output.
+
+ .. seealso::
+
+ **Original reference**:
+ Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
+ Hou T., Tegmark M. (2025).
+ *KAN: Kolmogorov-Arnold Networks*.
+ DOI: `arXiv preprint arXiv:2404.19756.
+ `_
+ """
+
+ def __init__(
+ self,
+ input_dimensions,
+ output_dimensions,
+ spline_order=3,
+ n_knots=10,
+ grid_range=[0, 1],
+ base_function=torch.nn.SiLU,
+ use_base_linear=True,
+ use_bias=True,
+ init_scale_spline=1e-2,
+ init_scale_base=1.0,
+ ):
+ """
+ Initialization of the :class:`KANBlock` class.
+
+ :param int input_dimensions: The number of input features.
+ :param int output_dimensions: The number of output features.
+ :param int spline_order: The order of each spline basis function.
+ Default is 3 (cubic splines).
+ :param int n_knots: The number of knots for each spline basis function.
+ Default is 10.
+ :param grid_range: The range for the spline knots. It must be either a
+ list or a tuple of the form [min, max]. Default is [0, 1].
+ :type grid_range: list | tuple.
+ :param torch.nn.Module base_function: The base activation function to be
+ applied to the input before the linear transformation. Default is
+ :class:`torch.nn.SiLU`.
+ :param bool use_base_linear: Whether to include a linear transformation
+ of the base function output. Default is True.
+ :param bool use_bias: Whether to include a bias term in the output.
+ Default is True.
+ :param init_scale_spline: The scale for initializing each spline
+ control points. Default is 1e-2.
+ :type init_scale_spline: float | int.
+ :param init_scale_base: The scale for initializing the base linear
+ weights. Default is 1.0.
+ :type init_scale_base: float | int.
+ :raises ValueError: If ``grid_range`` is not of length 2.
+ """
+ super().__init__()
+
+ # Check consistency
+ check_consistency(base_function, torch.nn.Module, subclass=True)
+ check_positive_integer(input_dimensions, strict=True)
+ check_positive_integer(output_dimensions, strict=True)
+ check_positive_integer(spline_order, strict=True)
+ check_positive_integer(n_knots, strict=True)
+ check_consistency(use_base_linear, bool)
+ check_consistency(use_bias, bool)
+ check_consistency(init_scale_spline, (int, float))
+ check_consistency(init_scale_base, (int, float))
+ check_consistency(grid_range, (int, float))
+
+ # Raise error if grid_range is not valid
+ if len(grid_range) != 2:
+ raise ValueError("Grid must be a list or tuple with two elements.")
+
+ # Knots for the spline basis functions
+ initial_knots = torch.ones(spline_order) * grid_range[0]
+ final_knots = torch.ones(spline_order) * grid_range[1]
+
+ # Number of internal knots
+ n_internal = max(0, n_knots - 2 * spline_order)
+
+ # Internal knots are uniformly spaced in the grid range
+ internal_knots = torch.linspace(
+ grid_range[0], grid_range[1], n_internal + 2
+ )[1:-1]
+
+ # Define the knots
+ knots = torch.cat((initial_knots, internal_knots, final_knots))
+ knots = knots.unsqueeze(0).repeat(input_dimensions, 1)
+
+ # Define the control points for the spline basis functions
+ control_points = (
+ torch.randn(
+ input_dimensions,
+ output_dimensions,
+ knots.shape[-1] - spline_order,
+ )
+ * init_scale_spline
+ )
+
+ # Define the vectorized spline module
+ self.spline = VectorizedSpline(
+ order=spline_order, knots=knots, control_points=control_points
+ )
+
+ # Initialize the base function
+ self.base_function = base_function()
+
+ # Initialize the base linear weights if needed
+ if use_base_linear:
+ self.base_weight = torch.nn.Parameter(
+ torch.randn(output_dimensions, input_dimensions)
+ * (init_scale_base / (input_dimensions**0.5))
+ )
+ else:
+ self.register_parameter("base_weight", None)
+
+ # Initialize the bias term if needed
+ if use_bias:
+ self.bias = torch.nn.Parameter(torch.zeros(output_dimensions))
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, x):
+ """
+ Forward pass of the Kolmogorov-Arnold block. The input is passed through
+ the spline transformation, optionally combined with a linear
+ transformation of the base function output, and then aggregated across
+ input dimensions to produce the final output.
+
+ :param x: The input tensor for the model.
+ :type x: torch.Tensor | LabelTensor
+ :return: The output tensor of the model.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ y = self.spline(x)
+
+ if self.base_weight is not None:
+ base_x = self.base_function(x)
+ base_out = torch.einsum("bi,oi->bio", base_x, self.base_weight)
+ y = y + base_out
+
+ # aggregate contributions from all input dimensions
+ y = y.sum(dim=1)
+
+ if self.bias is not None:
+ y = y + self.bias
+
+ return y
diff --git a/pina/model/block/low_rank_block.py b/pina/_src/model/block/low_rank_block.py
similarity index 98%
rename from pina/model/block/low_rank_block.py
rename to pina/_src/model/block/low_rank_block.py
index 1e8925d95..ad67b4dca 100644
--- a/pina/model/block/low_rank_block.py
+++ b/pina/_src/model/block/low_rank_block.py
@@ -2,7 +2,7 @@
import torch
-from ...utils import check_consistency
+from pina._src.core.utils import check_consistency
class LowRankBlock(torch.nn.Module):
diff --git a/pina/_src/model/block/message_passing/__init__.py b/pina/_src/model/block/message_passing/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/model/block/message_passing/deep_tensor_network_block.py b/pina/_src/model/block/message_passing/deep_tensor_network_block.py
similarity index 98%
rename from pina/model/block/message_passing/deep_tensor_network_block.py
rename to pina/_src/model/block/message_passing/deep_tensor_network_block.py
index a2de3097a..ed19578b7 100644
--- a/pina/model/block/message_passing/deep_tensor_network_block.py
+++ b/pina/_src/model/block/message_passing/deep_tensor_network_block.py
@@ -2,7 +2,7 @@
import torch
from torch_geometric.nn import MessagePassing
-from ....utils import check_positive_integer
+from pina._src.core.utils import check_positive_integer
class DeepTensorNetworkBlock(MessagePassing):
diff --git a/pina/model/block/message_passing/en_equivariant_network_block.py b/pina/_src/model/block/message_passing/en_equivariant_network_block.py
similarity index 98%
rename from pina/model/block/message_passing/en_equivariant_network_block.py
rename to pina/_src/model/block/message_passing/en_equivariant_network_block.py
index b8057b0f1..28a197230 100644
--- a/pina/model/block/message_passing/en_equivariant_network_block.py
+++ b/pina/_src/model/block/message_passing/en_equivariant_network_block.py
@@ -3,8 +3,8 @@
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
-from ....utils import check_positive_integer, check_consistency
-from ....model import FeedForward
+from pina._src.core.utils import check_positive_integer, check_consistency
+from pina._src.model.feed_forward import FeedForward
class EnEquivariantNetworkBlock(MessagePassing):
diff --git a/pina/model/block/message_passing/equivariant_graph_neural_operator_block.py b/pina/_src/model/block/message_passing/equivariant_graph_neural_operator_block.py
similarity index 97%
rename from pina/model/block/message_passing/equivariant_graph_neural_operator_block.py
rename to pina/_src/model/block/message_passing/equivariant_graph_neural_operator_block.py
index f6c739203..8a0f30aed 100644
--- a/pina/model/block/message_passing/equivariant_graph_neural_operator_block.py
+++ b/pina/_src/model/block/message_passing/equivariant_graph_neural_operator_block.py
@@ -1,8 +1,10 @@
"""Module for the Equivariant Graph Neural Operator block."""
import torch
-from ....utils import check_positive_integer
-from .en_equivariant_network_block import EnEquivariantNetworkBlock
+from pina._src.core.utils import check_positive_integer
+from pina._src.model.block.message_passing.en_equivariant_network_block import (
+ EnEquivariantNetworkBlock,
+)
class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
diff --git a/pina/model/block/message_passing/interaction_network_block.py b/pina/_src/model/block/message_passing/interaction_network_block.py
similarity index 98%
rename from pina/model/block/message_passing/interaction_network_block.py
rename to pina/_src/model/block/message_passing/interaction_network_block.py
index 7c6eb03f6..06fb39406 100644
--- a/pina/model/block/message_passing/interaction_network_block.py
+++ b/pina/_src/model/block/message_passing/interaction_network_block.py
@@ -2,8 +2,8 @@
import torch
from torch_geometric.nn import MessagePassing
-from ....utils import check_positive_integer
-from ....model import FeedForward
+from pina._src.core.utils import check_positive_integer
+from pina._src.model.feed_forward import FeedForward
class InteractionNetworkBlock(MessagePassing):
diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/_src/model/block/message_passing/radial_field_network_block.py
similarity index 97%
rename from pina/model/block/message_passing/radial_field_network_block.py
rename to pina/_src/model/block/message_passing/radial_field_network_block.py
index ef621b10e..ede0fb645 100644
--- a/pina/model/block/message_passing/radial_field_network_block.py
+++ b/pina/_src/model/block/message_passing/radial_field_network_block.py
@@ -3,8 +3,8 @@
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops
-from ....utils import check_positive_integer
-from ....model import FeedForward
+from pina._src.core.utils import check_positive_integer
+from pina._src.model.feed_forward import FeedForward
class RadialFieldNetworkBlock(MessagePassing):
diff --git a/pina/model/block/orthogonal.py b/pina/_src/model/block/orthogonal.py
similarity index 98%
rename from pina/model/block/orthogonal.py
rename to pina/_src/model/block/orthogonal.py
index cd45b3c72..24021ada6 100644
--- a/pina/model/block/orthogonal.py
+++ b/pina/_src/model/block/orthogonal.py
@@ -1,7 +1,7 @@
"""Module for the Orthogonal Block class."""
import torch
-from ...utils import check_consistency
+from pina._src.core.utils import check_consistency
class OrthogonalBlock(torch.nn.Module):
diff --git a/pina/model/block/pirate_network_block.py b/pina/_src/model/block/pirate_network_block.py
similarity index 97%
rename from pina/model/block/pirate_network_block.py
rename to pina/_src/model/block/pirate_network_block.py
index cfeb8410e..752f81901 100644
--- a/pina/model/block/pirate_network_block.py
+++ b/pina/_src/model/block/pirate_network_block.py
@@ -1,7 +1,7 @@
"""Module for the PirateNet block class."""
import torch
-from ...utils import check_consistency, check_positive_integer
+from pina._src.core.utils import check_consistency, check_positive_integer
class PirateNetBlock(torch.nn.Module):
diff --git a/pina/model/block/pod_block.py b/pina/_src/model/block/pod_block.py
similarity index 100%
rename from pina/model/block/pod_block.py
rename to pina/_src/model/block/pod_block.py
diff --git a/pina/model/block/rbf_block.py b/pina/_src/model/block/rbf_block.py
similarity index 99%
rename from pina/model/block/rbf_block.py
rename to pina/_src/model/block/rbf_block.py
index 8001381bc..061e43109 100644
--- a/pina/model/block/rbf_block.py
+++ b/pina/_src/model/block/rbf_block.py
@@ -4,7 +4,7 @@
import warnings
from itertools import combinations_with_replacement
import torch
-from ...utils import check_consistency
+from pina._src.core.utils import check_consistency
def linear(r):
diff --git a/pina/model/block/residual.py b/pina/_src/model/block/residual.py
similarity index 98%
rename from pina/model/block/residual.py
rename to pina/_src/model/block/residual.py
index f109ce03d..d1e8134cc 100644
--- a/pina/model/block/residual.py
+++ b/pina/_src/model/block/residual.py
@@ -2,7 +2,7 @@
import torch
from torch import nn
-from ...utils import check_consistency
+from pina._src.core.utils import check_consistency
class ResidualBlock(nn.Module):
diff --git a/pina/model/block/spectral.py b/pina/_src/model/block/spectral.py
similarity index 99%
rename from pina/model/block/spectral.py
rename to pina/_src/model/block/spectral.py
index aae915a42..fd5f48f6a 100644
--- a/pina/model/block/spectral.py
+++ b/pina/_src/model/block/spectral.py
@@ -2,7 +2,7 @@
import torch
from torch import nn
-from ...utils import check_consistency
+from pina._src.core.utils import check_consistency
######## 1D Spectral Convolution ###########
diff --git a/pina/model/block/stride.py b/pina/_src/model/block/stride.py
similarity index 98%
rename from pina/model/block/stride.py
rename to pina/_src/model/block/stride.py
index 2a26faf07..e802cddc0 100644
--- a/pina/model/block/stride.py
+++ b/pina/_src/model/block/stride.py
@@ -5,7 +5,7 @@
class Stride:
"""
- Stride class for continous convolution.
+ Stride class for continuous convolution.
"""
def __init__(self, dict_):
diff --git a/pina/model/block/utils_convolution.py b/pina/_src/model/block/utils_convolution.py
similarity index 100%
rename from pina/model/block/utils_convolution.py
rename to pina/_src/model/block/utils_convolution.py
diff --git a/pina/model/deeponet.py b/pina/_src/model/deeponet.py
similarity index 99%
rename from pina/model/deeponet.py
rename to pina/_src/model/deeponet.py
index c65f6b316..800f2acc3 100644
--- a/pina/model/deeponet.py
+++ b/pina/_src/model/deeponet.py
@@ -3,7 +3,7 @@
from functools import partial
import torch
from torch import nn
-from ..utils import check_consistency, is_function
+from pina._src.core.utils import check_consistency, is_function
class MIONet(torch.nn.Module):
diff --git a/pina/model/equivariant_graph_neural_operator.py b/pina/_src/model/equivariant_graph_neural_operator.py
similarity index 97%
rename from pina/model/equivariant_graph_neural_operator.py
rename to pina/_src/model/equivariant_graph_neural_operator.py
index 6b33df6db..3aa7dde69 100644
--- a/pina/model/equivariant_graph_neural_operator.py
+++ b/pina/_src/model/equivariant_graph_neural_operator.py
@@ -1,8 +1,10 @@
"""Module for the Equivariant Graph Neural Operator model."""
import torch
-from ..utils import check_positive_integer
-from .block.message_passing import EquivariantGraphNeuralOperatorBlock
+from pina._src.core.utils import check_positive_integer
+from pina._src.model.block.message_passing.equivariant_graph_neural_operator_block import (
+ EquivariantGraphNeuralOperatorBlock,
+)
class EquivariantGraphNeuralOperator(torch.nn.Module):
diff --git a/pina/model/feed_forward.py b/pina/_src/model/feed_forward.py
similarity index 99%
rename from pina/model/feed_forward.py
rename to pina/_src/model/feed_forward.py
index a1651b38b..fdf6bc91e 100644
--- a/pina/model/feed_forward.py
+++ b/pina/_src/model/feed_forward.py
@@ -2,8 +2,8 @@
import torch
from torch import nn
-from ..utils import check_consistency
-from .block.residual import EnhancedLinear
+from pina._src.core.utils import check_consistency
+from pina._src.model.block.residual import EnhancedLinear
class FeedForward(torch.nn.Module):
diff --git a/pina/model/fourier_neural_operator.py b/pina/_src/model/fourier_neural_operator.py
similarity index 97%
rename from pina/model/fourier_neural_operator.py
rename to pina/_src/model/fourier_neural_operator.py
index e1336c999..7517b39b4 100644
--- a/pina/model/fourier_neural_operator.py
+++ b/pina/_src/model/fourier_neural_operator.py
@@ -3,10 +3,14 @@
import warnings
import torch
from torch import nn
-from ..label_tensor import LabelTensor
-from ..utils import check_consistency
-from .block.fourier_block import FourierBlock1D, FourierBlock2D, FourierBlock3D
-from .kernel_neural_operator import KernelNeuralOperator
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency
+from pina._src.model.block.fourier_block import (
+ FourierBlock1D,
+ FourierBlock2D,
+ FourierBlock3D,
+)
+from pina._src.model.kernel_neural_operator import KernelNeuralOperator
class FourierIntegralKernel(torch.nn.Module):
diff --git a/pina/model/graph_neural_operator.py b/pina/_src/model/graph_neural_operator.py
similarity index 98%
rename from pina/model/graph_neural_operator.py
rename to pina/_src/model/graph_neural_operator.py
index 3cb5cdd31..e4d844fcb 100644
--- a/pina/model/graph_neural_operator.py
+++ b/pina/_src/model/graph_neural_operator.py
@@ -2,8 +2,8 @@
import torch
from torch.nn import Tanh
-from .block.gno_block import GNOBlock
-from .kernel_neural_operator import KernelNeuralOperator
+from pina._src.model.block.gno_block import GNOBlock
+from pina._src.model.kernel_neural_operator import KernelNeuralOperator
class GraphNeuralKernel(torch.nn.Module):
diff --git a/pina/model/kernel_neural_operator.py b/pina/_src/model/kernel_neural_operator.py
similarity index 99%
rename from pina/model/kernel_neural_operator.py
rename to pina/_src/model/kernel_neural_operator.py
index e3cb790e5..81d1be45d 100644
--- a/pina/model/kernel_neural_operator.py
+++ b/pina/_src/model/kernel_neural_operator.py
@@ -1,7 +1,7 @@
"""Module for the Kernel Neural Operator model class."""
import torch
-from ..utils import check_consistency
+from pina._src.core.utils import check_consistency
class KernelNeuralOperator(torch.nn.Module):
diff --git a/pina/_src/model/kolmogorov_arnold_network.py b/pina/_src/model/kolmogorov_arnold_network.py
new file mode 100644
index 000000000..1782aab4b
--- /dev/null
+++ b/pina/_src/model/kolmogorov_arnold_network.py
@@ -0,0 +1,105 @@
+import torch
+from pina._src.model.block.kan_block import KANBlock
+from pina._src.core.utils import check_consistency
+
+
+class KolmogorovArnoldNetwork(torch.nn.Module):
+ """
+ Implementation of Kolmogorov-Arnold Network (KAN).
+
+ The model consists of a sequence of KAN blocks, where each block applies a
+ spline transformation to the input, optionally combined with a linear
+ transformation of a base activation function.
+
+ .. seealso::
+
+ **Original reference**:
+ Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
+ Hou T., Tegmark M. (2025).
+ *KAN: Kolmogorov-Arnold Networks*.
+ DOI: `arXiv preprint arXiv:2404.19756.
+ `_
+ """
+
+ def __init__(
+ self,
+ layers,
+ spline_order=3,
+ n_knots=10,
+ grid_range=[-1, 1],
+ base_function=torch.nn.SiLU,
+ use_base_linear=True,
+ use_bias=True,
+ init_scale_spline=1e-2,
+ init_scale_base=1.0,
+ ):
+ """
+ Initialization of the :class:`KolmogorovArnoldNetwork` class.
+
+ :param layers: A list of integers specifying the sizes of each layer,
+ including input and output dimensions.
+ :type layers: list | tuple.
+ :param int spline_order: The order of each spline basis function.
+ Default is 3 (cubic splines).
+ :param int n_knots: The number of knots for each spline basis function.
+ Default is 3.
+ :param grid_range: The range for the spline knots. It must be either a
+ list or a tuple of the form [min, max]. Default is [0, 1].
+ :type grid_range: list | tuple.
+ :param torch.nn.Module base_function: The base activation function to be
+ applied to the input before the linear transformation. Default is
+ :class:`torch.nn.SiLU`.
+ :param bool use_base_linear: Whether to include a linear transformation
+ of the base function output. Default is True.
+ :param bool use_bias: Whether to include a bias term in the output.
+ Default is True.
+ :param init_scale_spline: The scale for initializing each spline
+ control points. Default is 1e-2.
+ :type init_scale_spline: float | int.
+ :param init_scale_base: The scale for initializing the base linear
+ weights. Default is 1.0.
+ :type init_scale_base: float | int.
+ :raises ValueError: If ``grid_range`` is not of length 2.
+ """
+ super().__init__()
+
+ # Check consistency -- all other checks are performed in KANBlock
+ check_consistency(layers, int)
+ if len(layers) < 2:
+ raise ValueError(
+ "`Provide at least two elements for layers (input and output)."
+ )
+
+ # Initialize KAN blocks
+ self.kan_layers = torch.nn.ModuleList(
+ [
+ KANBlock(
+ input_dimensions=layers[i],
+ output_dimensions=layers[i + 1],
+ spline_order=spline_order,
+ n_knots=n_knots,
+ grid_range=grid_range,
+ base_function=base_function,
+ use_base_linear=use_base_linear,
+ use_bias=use_bias,
+ init_scale_spline=init_scale_spline,
+ init_scale_base=init_scale_base,
+ )
+ for i in range(len(layers) - 1)
+ ]
+ )
+
+ def forward(self, x):
+ """
+ Forward pass of the KolmogorovArnoldNetwork model. It passes the input
+ through each KAN block in the network and returns the final output.
+
+ :param x: The input tensor for the model.
+ :type x: torch.Tensor | LabelTensor
+ :return: The output tensor of the model.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ for layer in self.kan_layers:
+ x = layer(x)
+
+ return x
diff --git a/pina/model/low_rank_neural_operator.py b/pina/_src/model/low_rank_neural_operator.py
similarity index 97%
rename from pina/model/low_rank_neural_operator.py
rename to pina/_src/model/low_rank_neural_operator.py
index 1a7082dff..049894001 100644
--- a/pina/model/low_rank_neural_operator.py
+++ b/pina/_src/model/low_rank_neural_operator.py
@@ -3,10 +3,10 @@
import torch
from torch import nn
-from ..utils import check_consistency
+from pina._src.core.utils import check_consistency
-from .kernel_neural_operator import KernelNeuralOperator
-from .block.low_rank_block import LowRankBlock
+from pina._src.model.kernel_neural_operator import KernelNeuralOperator
+from pina._src.model.block.low_rank_block import LowRankBlock
class LowRankNeuralOperator(KernelNeuralOperator):
diff --git a/pina/model/multi_feed_forward.py b/pina/_src/model/multi_feed_forward.py
similarity index 95%
rename from pina/model/multi_feed_forward.py
rename to pina/_src/model/multi_feed_forward.py
index f2f149ca6..df8fb19e2 100644
--- a/pina/model/multi_feed_forward.py
+++ b/pina/_src/model/multi_feed_forward.py
@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
import torch
-from .feed_forward import FeedForward
+from pina._src.model.feed_forward import FeedForward
class MultiFeedForward(torch.nn.Module, ABC):
diff --git a/pina/model/pirate_network.py b/pina/_src/model/pirate_network.py
similarity index 95%
rename from pina/model/pirate_network.py
rename to pina/_src/model/pirate_network.py
index 96102b41f..09aad269d 100644
--- a/pina/model/pirate_network.py
+++ b/pina/_src/model/pirate_network.py
@@ -1,8 +1,9 @@
"""Module for the PirateNet model class."""
import torch
-from .block import FourierFeatureEmbedding, PirateNetBlock
-from ..utils import check_consistency, check_positive_integer
+from pina._src.model.block.embedding import FourierFeatureEmbedding
+from pina._src.model.block.pirate_network_block import PirateNetBlock
+from pina._src.core.utils import check_consistency, check_positive_integer
class PirateNet(torch.nn.Module):
diff --git a/pina/model/sindy.py b/pina/_src/model/sindy.py
similarity index 97%
rename from pina/model/sindy.py
rename to pina/_src/model/sindy.py
index a40fa37b4..f69842a54 100644
--- a/pina/model/sindy.py
+++ b/pina/_src/model/sindy.py
@@ -2,7 +2,7 @@
from typing import Callable
import torch
-from ..utils import check_consistency, check_positive_integer
+from pina._src.core.utils import check_consistency, check_positive_integer
class SINDy(torch.nn.Module):
diff --git a/pina/model/spline.py b/pina/_src/model/spline.py
similarity index 98%
rename from pina/model/spline.py
rename to pina/_src/model/spline.py
index d9141fe8c..ed7f74678 100644
--- a/pina/model/spline.py
+++ b/pina/_src/model/spline.py
@@ -2,7 +2,7 @@
import warnings
import torch
-from ..utils import check_positive_integer, check_consistency
+from pina._src.core.utils import check_consistency, check_positive_integer
class Spline(torch.nn.Module):
@@ -202,7 +202,7 @@ def basis(self, x, collection=False):
:param torch.Tensor x: The points to be evaluated.
:param bool collection: If True, returns a list of basis functions for
all orders up to the spline order. Default is False.
- :raise ValueError: If ``collection`` is not a boolean.
+ :raises ValueError: If ``collection`` is not a boolean.
:return: The basis functions evaluated at x.
:rtype: torch.Tensor | list[torch.Tensor]
"""
@@ -290,7 +290,7 @@ def derivative(self, x, degree):
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:param int degree: The derivative degree to compute.
- :raise ValueError: If ``degree`` is not an integer.
+ :raises ValueError: If ``degree`` is not an integer.
:return: The derivative tensor.
:rtype: torch.Tensor
"""
diff --git a/pina/model/spline_surface.py b/pina/_src/model/spline_surface.py
similarity index 97%
rename from pina/model/spline_surface.py
rename to pina/_src/model/spline_surface.py
index 767e5b0dc..5550d761d 100644
--- a/pina/model/spline_surface.py
+++ b/pina/_src/model/spline_surface.py
@@ -2,8 +2,8 @@
import torch
from .spline import Spline
-from ..label_tensor import LabelTensor
-from ..utils import check_consistency, check_positive_integer
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency, check_positive_integer
class SplineSurface(torch.nn.Module):
@@ -134,8 +134,8 @@ def derivative(self, x, degree_u, degree_v):
parameter direction.
:param int degree_v: The degree of the derivative along the second
parameter direction.
- :raise ValueError: If ``degree_u`` is not an integer.
- :raise ValueError: If ``degree_v`` is not an integer.
+ :raises ValueError: If ``degree_u`` is not an integer.
+ :raises ValueError: If ``degree_v`` is not an integer.
:return: The derivative tensor.
:rtype: torch.Tensor
"""
diff --git a/pina/_src/model/vectorized_spline.py b/pina/_src/model/vectorized_spline.py
new file mode 100644
index 000000000..1dfe323e6
--- /dev/null
+++ b/pina/_src/model/vectorized_spline.py
@@ -0,0 +1,652 @@
+"""Vectorized univariate B-spline model with per-spline knots."""
+
+import warnings
+import torch
+from pina._src.core.utils import check_consistency, check_positive_integer
+
+
+class VectorizedSpline(torch.nn.Module):
+ r"""
+ The vectorized B-spline model class.
+
+ A :class:`VectorizedSpline` represents a vector spline, i.e., a collection
+ of independent univariate B-splines evaluated in parallel. Each univariate
+ spline has its own knot vector and its own control points, and acts on one
+ input feature.
+
+ Given ``s`` univariate splines, the vector spline maps an input
+ :math:`x = (x^{(1)}, \dots, x^{(s)}) \in \mathbb{R}^s` to an output obtained
+ by evaluating each univariate spline on its corresponding scalar input
+ :math:`x^{(j)}`.
+
+ For the :math:`j`-th univariate spline of order :math:`k`, the output is
+ defined as
+
+ .. math::
+
+ S^{(j)}(x^{(j)}) = \sum_{i=1}^{n_j} B_{i,k}^{(j)}(x^{(j)}) C_i^{(j)},
+
+ where:
+
+ - :math:`C^{(j)}` are the control points of the :math:`j`-th univariate
+ spline. In the scalar-output case, :math:`C^{(j)} \in \mathbb{R}^{n_j}`.
+ More generally, each univariate spline may have output dimension
+ :math:`o`, so :math:`C^{(j)} \in \mathbb{R}^{o \times n_j}`.
+ - :math:`B_{i,k}^{(j)}(x)` are the B-spline basis functions of order
+ :math:`k`, i.e., piecewise polynomials of degree :math:`k-1`, associated
+ with the knot vector of the :math:`j`-th univariate spline.
+ - :math:`X^{(j)} = \{x_1^{(j)}, x_2^{(j)}, \dots, x_{m_j}^{(j)}\}` is the
+ non-decreasing knot vector of the :math:`j`-th univariate spline.
+
+ If the first and last knots of a given univariate spline are repeated
+ :math:`k` times, then that univariate spline interpolates its first and last
+ control points.
+
+ The full vector spline evaluates all univariate splines in parallel. If each
+ univariate spline has output dimension :math:`o`, then before optional
+ aggregation the output has shape ``[batch, s, o]``.
+
+ .. note::
+
+ Each univariate spline is forced to be zero outside the interval defined
+ by the first and last knots of its own knot vector.
+
+ .. note::
+
+ This class does not represent a single multivariate spline
+ :math:`\mathbb{R}^s \to \mathbb{R}^o` with a genuinely multivariate
+ basis. Instead, it represents a vector of splines built from ``s``
+ independent univariate splines, one for each input feature.
+
+ .. note::
+
+ When using the :meth:`derivative` method of this class, derivatives are
+ computed directly in vectorized form and returned with the correct
+ shape. In contrast, when relying on ``autograd``, derivatives must be
+ computed separately for each output dimension of each univariate spline
+ and then combined, since autograd does not natively handle this
+ vectorized structure.
+
+ :Example:
+
+ >>> from pina.model import VectorizedSpline
+ >>> import torch
+
+ >>> knt1 = torch.tensor([
+ ... [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
+ ... [0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0],
+ ... ])
+ >>> spline1 = VectorizedSpline(order=3, knots=knt1, control_points=None)
+
+ >>> knt2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto", "n_splines": 2}
+ >>> spline2 = VectorizedSpline(order=3, knots=knt2, control_points=None)
+
+ >>> knt3 = torch.tensor([
+ ... [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
+ ... [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
+ ... ])
+ >>> ctrl3 = torch.tensor([
+ ... [0.0, 1.0, 3.0, 2.0],
+ ... [1.0, 0.0, 2.0, 1.0],
+ ... ])
+ >>> spline3 = VectorizedSpline(order=3, knots=knt3, control_points=ctrl3)
+ """
+
+ def __init__(
+ self,
+ order=4,
+ knots=None,
+ control_points=None,
+ aggregate_output=None,
+ ):
+ """
+ Initialization of the :class:`VectorizedSpline` class.
+
+ :param int order: The order of each univariate spline. The corresponding
+ basis functions are polynomials of degree ``order - 1``.
+ Default is 4.
+ :param knots: The knots of the spline. If a tensor is provided, it must
+ have shape ``[s, n]``, where ``s`` is the number of univariate
+ splines and ``n`` is the number of knots per univariate spline. If a
+ dictionary is provided, it must contain the keys ``"n"``, ``"min"``,
+ ``"max"``, ``"mode"``, and ``"n_splines"``. Here, ``"n"`` specifies
+ the number of knots for each univariate spline, ``"min"`` and
+ ``"max"`` define the interval, ``"mode"`` selects the sampling
+ strategy, and ``"n_splines"`` specifies the number of univariate
+ splines. The supported modes are ``"uniform"``, where the knots are
+ evenly spaced over :math:`[min, max]`, and ``"auto"``, where knots
+ are constructed to ensure that each univariate spline interpolates
+ the first and last control points. In this case, the number of knots
+ is adjusted if :math:`n < 2 * order`. If None is given, knots are
+ initialized automatically over :math:`[0, 1]` ensuring interpolation
+ of the first and last control points. Default is None.
+ :type knots: torch.Tensor | dict
+ :param torch.Tensor control_points: The control points tensor. The
+ tensor must be either of shape ``[s, o, c]`` or ``[s, c]``, where
+ each univariate spline has ``c`` control points and output dimension
+ ``o``. In the latter case, the control points are expanded to shape
+ ``[s, 1, c]``. If None, control points are initialized to learnable
+ parameters with zero initial value. Default is None.
+ :param str aggregate_output: If None, the output of each univariate
+ spline is returned separately, resulting in an output of shape
+ ``[batch, s, o]``, where ``s`` is the number of univariate splines
+ and ``o`` is the output dimension of each univariate spline. If set
+ to ``"mean"`` or ``"sum"``, the output is aggregated accordingly
+ across the last dimension, resulting in an output of shape
+ ``[batch, s]``. Default is None.
+ :raises AssertionError: If ``order`` is not a positive integer.
+ :raises ValueError: If ``knots`` is neither a torch.Tensor nor a
+ dictionary, when provided.
+ :raises ValueError: If ``aggregate_output`` is not None, "mean", or
+ "sum".
+ :raises ValueError: If ``control_points`` is not a torch.Tensor,
+ when provided.
+ :raises ValueError: If both ``knots`` and ``control_points`` are None.
+ :raises ValueError: If ``knots`` is not two-dimensional, after
+ processing.
+ :raises ValueError: If ``control_points``, after expansion when
+ two-dimensional, is not three-dimensional.
+ :raises ValueError: If, for each univariate spline, the number of
+ ``knots`` is not equal to the sum of ``order`` and the number of
+ ``control_points.``
+ :raises UserWarning: If, for each univariate spline, the number of
+ ``control_points`` is lower than the ``order``, resulting in a
+ degenerate spline.
+ :raises ValueError: If the number of univariate splines in ``knots`` and
+ ``control_points`` do not match.
+ """
+
+ super().__init__()
+
+ # Check consistency
+ check_positive_integer(value=order, strict=True)
+ check_consistency(knots, (type(None), torch.Tensor, dict))
+ check_consistency(control_points, (type(None), torch.Tensor))
+
+ # Raise error if neither knots nor control points are provided
+ if knots is None and control_points is None:
+ raise ValueError("knots and control_points cannot both be None.")
+
+ # Raise error if aggregate_output is not None, "mean", or "sum"
+ if aggregate_output not in (None, "mean", "sum"):
+ raise ValueError(
+ f"aggregate_output must be None, 'mean', or 'sum'."
+ f" Got {aggregate_output}."
+ )
+
+ # Initialize knots if not provided
+ if knots is None and control_points is not None:
+ knots = {
+ "n": control_points.shape[-1] + order,
+ "min": 0,
+ "max": 1,
+ "n_splines": control_points.shape[0],
+ "mode": "auto",
+ }
+
+ # Initialization - knots and control points managed by their setters
+ self.order = order
+ self.knots = knots
+ self.control_points = control_points
+ self.aggregate_output = aggregate_output
+
+ # Check dimensionality of control points
+ if self.control_points.ndim != 3:
+ raise ValueError("control_points must be three-dimensional.")
+
+ # Raise error if #knots != order + #control_points
+ if self.knots.shape[-1] != self.order + self.control_points.shape[-1]:
+ raise ValueError(
+ f" The number of knots per spline must be equal to order + the"
+ f" number of control points. Got {self.knots.shape[-1]} knots"
+ f" per spline, {self.control_points.shape[-1]} control points,"
+ f" and {self.order} order."
+ )
+
+ # Raise warning if spline is degenerate
+ if self.control_points.shape[-1] < self.order:
+ warnings.warn(
+ "The number of control points per spline is smaller than the"
+ " spline order. This creates a degenerate spline with limited"
+ " flexibility.",
+ UserWarning,
+ )
+
+ # Raise error if knots and control points have different # of splines
+ if self.knots.shape[0] != self.control_points.shape[0]:
+ raise ValueError(
+ f"The number of splines must be the same for knots and"
+ f" control points. Got {self.knots.shape[0]} splines for knots"
+ f" and {self.control_points.shape[0]} splines for control"
+ f" points."
+ )
+
+ # Precompute boundary interval index
+ self.register_buffer(
+ "_boundary_interval_idx", self._compute_boundary_interval()
+ )
+
+ # Precompute denominators used in derivative formulas
+ self._compute_derivative_denominators()
+
+ def _compute_boundary_interval(self):
+ """
+ Precompute the index of the rightmost non-degenerate interval to improve
+ performance, eliminating the need to perform a search loop in the basis
+ function on each call.
+
+ :return: The index of the rightmost non-degenerate interval for each
+ univariate spline.
+ :rtype: torch.Tensor
+ """
+ # Compute the differences between consecutive knots for each spline
+ diffs = self._knots[:, 1:] - self._knots[:, :-1]
+ valid = diffs > 0
+
+ # Initialize idx tensor to store the last valid interval for each spline
+ idx = torch.zeros(
+ self._knots.shape[0], dtype=torch.long, device=self._knots.device
+ )
+
+ # For each spline, find the last idx where interval is non-degenerate
+ for s in range(self._knots.shape[0]):
+ valid_s = torch.nonzero(valid[s], as_tuple=False)
+ idx[s] = valid_s[-1, 0] if valid_s.numel() > 0 else 0
+
+ return idx
+
+ def _compute_derivative_denominators(self):
+ """
+ Precompute the denominators used in the derivatives for all orders up to
+ the spline order to avoid redundant calculations.
+ """
+ # Precompute for order 2 to k
+ for i in range(2, self.order + 1):
+
+ # Denominators for the derivative recurrence relations
+ left_den = self.knots[:, i - 1 : -1] - self.knots[:, :-i]
+ right_den = self.knots[:, i:] - self.knots[:, 1 : -i + 1]
+
+ # If consecutive knots are equal, set left and right factors to zero
+ left_fac = torch.where(
+ torch.abs(left_den) > 1e-10,
+ (i - 1) / left_den,
+ torch.zeros_like(left_den),
+ )
+ right_fac = torch.where(
+ torch.abs(right_den) > 1e-10,
+ (i - 1) / right_den,
+ torch.zeros_like(right_den),
+ )
+
+ # Register buffers
+ self.register_buffer(f"_left_factor_order_{i}", left_fac)
+ self.register_buffer(f"_right_factor_order_{i}", right_fac)
+
+ def basis(self, x, collection=False):
+ """
+ Evaluate the B-spline basis functions for each univariate spline.
+
+ This method applies the Cox-de Boor recursion in vectorized form across
+ all univariate splines of the vector spline.
+
+ :param torch.Tensor x: The points to be evaluated.
+ :param bool collection: If True, returns a list of basis functions for
+ all orders up to the spline order. Default is False.
+ :raises ValueError: If ``collection`` is not a boolean.
+ :raises ValueError: If ``x`` is not two-dimensional.
+ :raises ValueError: If the number of input features does not match
+ the number of univariate splines.
+ :return: The basis functions evaluated at x.
+ :rtype: torch.Tensor
+ """
+ # Check consistency
+ check_consistency(collection, bool)
+
+ # Ensure x is a tensor of the same dtype as knots
+ x = x.as_subclass(torch.Tensor).to(dtype=self.knots.dtype)
+
+ # Raise error if x does not have shape (batch, s)
+ if x.ndim != 2:
+ raise ValueError(
+ f"The input must have shape (batch, s). Got {x.shape}."
+ )
+
+ # Raise error if x has different number of splines than knots
+ if x.shape[1] != self.knots.shape[0]:
+ raise ValueError(
+ f"The number of input features must be the same as the number"
+ f" of univariate splines. Got {x.shape[1]} input features,"
+ f" but {self.knots.shape[0]} univariate splines."
+ )
+
+ # Add a final dimension to x for broadcasting
+ x = x.unsqueeze(-1)
+
+ # Add an initial dimension to knots for broadcasting
+ knots = self.knots.unsqueeze(0)
+
+ # Base case of recursion: indicator functions for the intervals
+ basis = (x >= knots[..., :-1]) & (x < knots[..., 1:])
+ basis = basis.to(x.dtype)
+
+ # Extract left and right knots of the boundary interval for each spline
+ range_tensor = torch.arange(self.knots.shape[0], device=x.device)
+ knot_left = self.knots[range_tensor, self._boundary_interval_idx]
+ knot_right = self.knots[range_tensor, self._boundary_interval_idx + 1]
+
+ # Identify points at the rightmost boundary
+ at_rightmost_boundary = (
+ x.squeeze(-1) >= knot_left.unsqueeze(0)
+ ) & torch.isclose(
+ x.squeeze(-1), knot_right.unsqueeze(0), rtol=1e-8, atol=1e-10
+ )
+
+ # Ensure the correct value is set at the rightmost boundary
+ if torch.any(at_rightmost_boundary):
+ b_idx, s_idx = torch.nonzero(at_rightmost_boundary, as_tuple=True)
+ basis[b_idx, s_idx, self._boundary_interval_idx[s_idx]] = 1.0
+
+ # If returning the whole collection, initialize list
+ if collection:
+ basis_collection = [None, basis]
+
+ # Cox-de Boor recursion -- iterative case
+ for i in range(1, self.order):
+
+ # Compute the denominators for both terms of the recursion
+ denom1 = knots[..., i:-1] - knots[..., : -(i + 1)]
+ denom2 = knots[..., i + 1 :] - knots[..., 1:-i]
+
+ # Ensure no division by zero
+ denom1 = torch.where(
+ denom1.abs() < 1e-8, torch.ones_like(denom1), denom1
+ )
+ denom2 = torch.where(
+ denom2.abs() < 1e-8, torch.ones_like(denom2), denom2
+ )
+
+ # Compute the two terms of the recursion
+ term1 = ((x - knots[..., : -(i + 1)]) / denom1) * basis[..., :-1]
+ term2 = ((knots[..., i + 1 :] - x) / denom2) * basis[..., 1:]
+
+ # Combine terms to get the new basis
+ basis = term1 + term2
+
+ if collection:
+ basis_collection.append(basis)
+
+ return basis_collection if collection else basis
+
+ def forward(self, x):
+ """
+ Forward pass for the :class:`VectorizedSpline` model. Each univariate
+ spline is evaluated independently on its corresponding input feature.
+
+ The input is expected to have shape ``[batch, s]``, where ``s`` is the
+ number of univariate splines. The output has shape ``[batch, s, o]``,
+ where ``o`` is the output dimension of each univariate spline, unless an
+ aggregation method is specified. If both ``s`` and ``o`` are 1, the
+ output is aggregated across the last dimension, resulting in an output
+ of shape ``[batch, s]``. If ``aggregate_output`` is set to ``"mean"`` or
+ ``"sum"``, the output is aggregated across the last dimension, resulting
+ in an output of shape ``[batch, s]``.
+
+ :param x: The input tensor.
+ :type x: torch.Tensor | LabelTensor
+ :return: The output tensor.
+ :rtype: torch.Tensor
+ """
+ # Compute the basis functions at x
+ basis = self.basis(x)
+
+ # Compute the output for each spline
+ out = torch.einsum("bsc,soc->bso", basis, self.control_points)
+
+ # Aggregate output if needed
+ if self.aggregate_output == "mean":
+ out = out.mean(dim=-1)
+ elif self.aggregate_output == "sum":
+ out = out.sum(dim=-1)
+ elif out.shape[1] == 1 and out.shape[2] == 1:
+ out = out.squeeze(-1)
+
+ return out
+
+ def derivative(self, x, degree):
+ """
+ Compute the ``degree``-th derivative of each univariate spline at the
+ given input points.
+
+ The output has shape ``[batch, s, o]``, where ``o`` is the output
+ dimension of each univariate spline, unless an aggregation method is
+ specified. If both ``s`` and ``o`` are 1, the output is aggregated
+ across the last dimension, resulting in an output of shape
+ ``[batch, s]``. If ``aggregate_output`` is set to ``"mean"`` or
+ ``"sum"``, the output is aggregated across the last dimension, resulting
+ in an output of shape ``[batch, s]``.
+
+ :param x: The input tensor.
+ :type x: torch.Tensor | LabelTensor
+ :param int degree: The derivative degree to compute.
+ :return: The derivative tensor.
+ :rtype: torch.Tensor
+ """
+ # Check consistency
+ check_positive_integer(degree, strict=False)
+
+ # Compute basis derivative
+ der = self._basis_derivative(x.as_subclass(torch.Tensor), degree=degree)
+
+ # Compute the output for each spline
+ out = torch.einsum("bsc,soc->bso", der, self.control_points)
+
+ # Aggregate output if needed
+ if self.aggregate_output == "mean":
+ out = out.mean(dim=-1)
+ elif self.aggregate_output == "sum":
+ out = out.sum(dim=-1)
+ elif out.shape[1] == 1 and out.shape[2] == 1:
+ out = out.squeeze(-1)
+
+ return out
+
+ def _basis_derivative(self, x, degree):
+ """
+ Compute the ``degree``-th derivative of the vectorized spline basis
+ functions at the given input points using an iterative approach.
+
+ :param torch.Tensor x: The points to be evaluated.
+ :param int degree: The derivative degree to compute.
+ :return: The derivative of the basis functions of order ``self.order``.
+ :rtype: torch.Tensor
+ """
+ # Compute the whole basis collection
+ basis = self.basis(x, collection=True)
+
+ # Derivatives initialization (dummy at index 0 for convenience)
+ derivatives = [None] + [basis[o] for o in range(1, self.order + 1)]
+
+ # Iterate over derivative degrees
+ for _ in range(1, degree + 1):
+
+ # Current degree derivatives (with dummy at index 0 for convenience)
+ current_der = [None] * (self.order + 1)
+ current_der[1] = torch.zeros_like(derivatives[1])
+
+ # Iterate over basis orders
+ for o in range(2, self.order + 1):
+
+ # Retrieve precomputed factors
+ left_fac = getattr(self, f"_left_factor_order_{o}")
+ right_fac = getattr(self, f"_right_factor_order_{o}")
+
+ # derivatives[o - 1] has shape [b, s, m]
+ # Slice previous derivatives to align
+ left_part = derivatives[o - 1][..., :-1]
+ right_part = derivatives[o - 1][..., 1:]
+
+ # Broadcast factors over batch dims
+ left_fac = left_fac.unsqueeze(0)
+ right_fac = right_fac.unsqueeze(0)
+
+ # Compute current derivatives
+ current_der[o] = left_fac * left_part - right_fac * right_part
+
+ # Update derivatives for next degree
+ derivatives = current_der
+
+ return derivatives[self.order]
+
+ @property
+ def control_points(self):
+ """
+ The control points of the spline.
+
+ :return: The control points.
+ :rtype: torch.Tensor
+ """
+ return self._control_points
+
+ @control_points.setter
+ def control_points(self, control_points):
+ """
+ Set the control points of the spline.
+
+ :param torch.Tensor control_points: The control points tensor. The
+ tensor must be either of shape ``[s, o, c]`` or ``[s, c]``, where
+ each univariate spline has ``c`` control points and output dimension
+ ``o``. In the latter case, the control points are expanded to shape
+ ``[s, 1, c]``.
+ :raises ValueError: If there are not enough knots to define the control
+ points, due to the relation: #knots = order + #control_points.
+ """
+ # If control points are not provided, initialize them
+ if control_points is None:
+
+ # Check that there are enough knots to define control points
+ if self.knots.shape[-1] < self.order + 1:
+ raise ValueError(
+ f"Not enough knots to define control points. Got"
+ f" {self.knots.shape[-1]} knots for each univariate spline,"
+ f" but need at least {self.order + 1}."
+ )
+
+ # Initialize control points to zero
+ control_points = torch.zeros(
+ self.knots.shape[0], 1, self.knots.shape[-1] - self.order
+ )
+
+ # If a the control points are 2D, add an output dimension of size 1
+ if control_points.ndim == 2:
+ control_points = control_points.unsqueeze(1)
+
+ # Set control points
+ self._control_points = torch.nn.Parameter(
+ control_points, requires_grad=True
+ )
+
+ @property
+ def knots(self):
+ """
+ The knots of the spline.
+
+ :return: The knots.
+ :rtype: torch.Tensor
+ """
+ return self._knots
+
+ @knots.setter
+ def knots(self, value):
+ """
+ Set the knots of the spline.
+ :param value: The knots of the spline. If a tensor is provided, it must
+ have shape ``[s, n]``, where ``s`` is the number of univariate
+ splines and ``n`` is the number of knots per univariate spline. If a
+ dictionary is provided, it must contain the keys ``"n"``, ``"min"``,
+ ``"max"``, ``"mode"``, and ``"n_splines"``. Here, ``"n"`` specifies
+ the number of knots for each univariate spline, ``"min"`` and
+ ``"max"`` define the interval, ``"mode"`` selects the sampling
+ strategy, and ``"n_splines"`` specifies the number of univariate
+ splines. The supported modes are ``"uniform"``, where the knots are
+ evenly spaced over :math:`[min, max]`, and ``"auto"``, where knots
+ are constructed to ensure that each univariate spline interpolates
+ the first and last control points. In this case, the number of knots
+ is adjusted if :math:`n < 2 * order`. If None is given, knots are
+ initialized automatically over :math:`[0, 1]` ensuring interpolation
+ of the first and last control points.
+ :type value: torch.Tensor | dict
+ :raises ValueError: If a dictionary is provided but does not contain
+ the required keys.
+ :raises ValueError: If the mode specified in the dictionary is invalid.
+ :raises ValueError: If knots is not two-dimensional after processing.
+ """
+ # If a dictionary is provided, initialize knots accordingly
+ if isinstance(value, dict):
+
+ # Check that required keys are present
+ required_keys = {"n", "min", "max", "mode", "n_splines"}
+ if not required_keys.issubset(value.keys()):
+ raise ValueError(
+ f"When providing knots as a dictionary, the following "
+ f"keys must be present: {required_keys}. Got "
+ f"{value.keys()}."
+ )
+
+ # Save number of splines for later use
+ n_splines = value["n_splines"]
+
+ # Uniform sampling of knots
+ if value["mode"] == "uniform":
+ value = torch.linspace(value["min"], value["max"], value["n"])
+
+ # Automatic sampling of interpolating knots
+ elif value["mode"] == "auto":
+
+ # Repeat the first and last knots 'order' times
+ initial_knots = torch.ones(self.order) * value["min"]
+ final_knots = torch.ones(self.order) * value["max"]
+
+ # Number of internal knots
+ n_internal = value["n"] - 2 * self.order
+
+ # If no internal knots are needed, just concatenate boundaries
+ if n_internal <= 0:
+ value = torch.cat((initial_knots, final_knots))
+
+ # Else, sample internal knots uniformly and exclude boundaries
+ # Recover the correct number of internal knots when slicing by
+ # adding 2 to n_internal
+ else:
+ internal_knots = torch.linspace(
+ value["min"], value["max"], n_internal + 2
+ )[1:-1]
+ value = torch.cat(
+ (initial_knots, internal_knots, final_knots)
+ )
+
+ # Raise error if mode is invalid
+ else:
+ raise ValueError(
+ f"Invalid mode for knots initialization. Got "
+ f"{value['mode']}, but expected 'uniform' or 'auto'."
+ )
+
+ # Repeat the knot vector for each spline
+ value = value.unsqueeze(0).repeat(n_splines, 1)
+
+ # Set knots
+ self.register_buffer("_knots", value.sort(dim=-1).values)
+
+ # Check dimensionality of knots
+ if self.knots.ndim != 2:
+ raise ValueError("knots must be two-dimensional.")
+
+ # Recompute boundary interval when knots change
+ if hasattr(self, "_boundary_interval_idx"):
+ self.register_buffer(
+ "_boundary_interval_idx", self._compute_boundary_interval()
+ )
+
+ # Recompute derivative denominators when knots change
+ self._compute_derivative_denominators()
diff --git a/pina/_src/optim/__init__.py b/pina/_src/optim/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/optim/optimizer_interface.py b/pina/_src/optim/optimizer_interface.py
new file mode 100644
index 000000000..b60e23624
--- /dev/null
+++ b/pina/_src/optim/optimizer_interface.py
@@ -0,0 +1,30 @@
+"""Module for the Optimizer Interface."""
+
+from abc import ABCMeta, abstractmethod
+
+
+class OptimizerInterface(metaclass=ABCMeta):
+ """
+ Abstract interface for all optimizers.
+ """
+
+ @abstractmethod
+ def hook(self, parameters):
+ """
+ Execute custom logic associated with the optimizer instance.
+
+ This method is intended to encapsulate any additional behavior that
+ should be triggered during the optimization process.
+
+ :param dict parameters: The parameters of the model to be optimized.
+ """
+
+ @property
+ @abstractmethod
+ def instance(self):
+ """
+ The underlying optimizer object.
+
+ :return: The optimizer instance.
+ :rtype: object
+ """
diff --git a/pina/_src/optim/scheduler_interface.py b/pina/_src/optim/scheduler_interface.py
new file mode 100644
index 000000000..55951ee0e
--- /dev/null
+++ b/pina/_src/optim/scheduler_interface.py
@@ -0,0 +1,31 @@
+"""Module for the Scheduler Interface."""
+
+from abc import ABCMeta, abstractmethod
+
+
+class SchedulerInterface(metaclass=ABCMeta):
+ """
+ Abstract interface for all schedulers.
+ """
+
+ @abstractmethod
+ def hook(self, optimizer):
+ """
+ Execute custom logic associated with the scheduler instance.
+
+ This method is intended to encapsulate any additional behavior that
+ should be triggered during the optimization process.
+
+ :param OptimizerInterface optimizer: The optimizer instance associated
+ with the scheduler.
+ """
+
+ @property
+ @abstractmethod
+ def instance(self):
+ """
+ The underlying scheduler object.
+
+ :return: The scheduler instance.
+ :rtype: object
+ """
diff --git a/pina/_src/optim/torch_optimizer.py b/pina/_src/optim/torch_optimizer.py
new file mode 100644
index 000000000..a37bfbfec
--- /dev/null
+++ b/pina/_src/optim/torch_optimizer.py
@@ -0,0 +1,59 @@
+"""Module for wrapping PyTorch optimizers."""
+
+import torch
+from pina._src.core.utils import check_consistency
+from pina._src.optim.optimizer_interface import OptimizerInterface
+
+
+class TorchOptimizer(OptimizerInterface):
+ """
+ The wrapper class for PyTorch optimizers.
+
+ This class wraps a ``torch.optim.Optimizer`` class and defers its
+ instantiation until runtime. It enables a consistent interface across
+ different optimizer backends while leveraging PyTorch’s optimization
+ algorithms.
+ """
+
+ def __init__(self, optimizer_class, **kwargs):
+ """
+ Initialization of the :class:`TorchOptimizer` class.
+
+ :param torch.optim.Optimizer optimizer_class: The subclass of
+ ``torch.optim.Optimizer`` to be instantiated.
+ :param dict kwargs: Additional keyword arguments forwarded to the
+ optimizer constructor. See more
+ `here `_.
+ :raises ValueError: If ``optimizer_class`` is not a subclass of
+ ``torch.optim.Optimizer``.
+ """
+ # Check consistency
+ check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)
+
+ # Initialize attributes
+ self.optimizer_class = optimizer_class
+ self.kwargs = kwargs
+ self._optimizer_instance = None
+
+ def hook(self, parameters):
+ """
+ Execute custom logic associated with the optimizer instance.
+
+ This method is intended to encapsulate any additional behavior that
+ should be triggered during the optimization process.
+
+ :param dict parameters: The parameters of the model to be optimized.
+ """
+ self._optimizer_instance = self.optimizer_class(
+ parameters, **self.kwargs
+ )
+
+ @property
+ def instance(self):
+ """
+ The underlying optimizer object.
+
+ :return: The optimizer instance.
+ :rtype: torch.optim.Optimizer
+ """
+ return self._optimizer_instance
diff --git a/pina/_src/optim/torch_scheduler.py b/pina/_src/optim/torch_scheduler.py
new file mode 100644
index 000000000..f33b6020f
--- /dev/null
+++ b/pina/_src/optim/torch_scheduler.py
@@ -0,0 +1,62 @@
+"""Module for wrapping PyTorch schedulers."""
+
+from torch.optim.lr_scheduler import LRScheduler
+from pina._src.core.utils import check_consistency
+from pina._src.optim.optimizer_interface import OptimizerInterface
+from pina._src.optim.scheduler_interface import SchedulerInterface
+
+
+class TorchScheduler(SchedulerInterface):
+ """
+ The wrapper class for PyTorch schedulers.
+
+ This class wraps a ``torch.optim.lr_scheduler.LRScheduler`` class and defers
+ its instantiation until runtime, once the optimizer instance is available.
+ """
+
+ def __init__(self, scheduler_class, **kwargs):
+ """
+ Initialization of the :class:`TorchScheduler` class.
+
+ :param torch.optim.LRScheduler scheduler_class: The subclass of
+ ``torch.optim.lr_scheduler.LRScheduler`` to be instantiated.
+ :param dict kwargs: Additional keyword arguments forwarded to the
+ scheduler constructor. See more
+ `here `_.
+ :raises ValueError: If ``scheduler_class`` is not a subclass of
+ ``torch.optim.lr_scheduler.LRScheduler``.
+ """
+ # Check consistency
+ check_consistency(scheduler_class, LRScheduler, subclass=True)
+
+ # Initialize attributes
+ self.scheduler_class = scheduler_class
+ self.kwargs = kwargs
+ self._scheduler_instance = None
+
+ def hook(self, optimizer):
+ """
+ Initialize the scheduler instance with the given parameters.
+
+ :param OptimizerInterface optimizer: The optimizer instance associated
+ with the scheduler.
+ :raises ValueError: If ``optimizer`` is not an instance of
+ :class:`OptimizerInterface`.
+ """
+ # Check consistency
+ check_consistency(optimizer, OptimizerInterface)
+
+ # Initialize the scheduler instance
+ self._scheduler_instance = self.scheduler_class(
+ optimizer.instance, **self.kwargs
+ )
+
+ @property
+ def instance(self):
+ """
+ The underlying scheduler object.
+
+ :return: The scheduler instance.
+ :rtype: torch.optim.lr_scheduler.LRScheduler
+ """
+ return self._scheduler_instance
diff --git a/pina/_src/problem/__init__.py b/pina/_src/problem/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/problem/base_problem.py b/pina/_src/problem/base_problem.py
new file mode 100644
index 000000000..8bccc0d79
--- /dev/null
+++ b/pina/_src/problem/base_problem.py
@@ -0,0 +1,299 @@
+"""Module for the BaseProblem class."""
+
+import warnings
+from copy import deepcopy
+from pina._src.problem.problem_interface import ProblemInterface
+from pina._src.domain.domain_interface import DomainInterface
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.condition.condition import Condition
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+from pina._src.core.utils import (
+ check_consistency,
+ check_positive_integer,
+ merge_tensors,
+)
+
+
+class BaseProblem(ProblemInterface):
+ """
+ Base class for all problems, implementing common functionality.
+
+ A problem is defined by core components, including input and output
+ variables, a set of conditions to be satisfied, and optionally the domains
+ on which these conditions are defined.
+
+ All problems must inherit from this class and implement abstract methods
+ defined in :class:`~pina.problem.problem_interface.ProblemInterface`.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ def __init__(self):
+ """
+ Initialization of the :class:`BaseProblem` class.
+ """
+ self._discretised_domains = {}
+
+ # Create a correspondence between the problem and the conditions
+ for condition_name in self.conditions:
+ self.conditions[condition_name].problem = self
+ self.conditions[condition_name].name = condition_name
+
+ # Create a dictionary to store the domains of the problem
+ if not hasattr(self, "domains"):
+ self.domains = {}
+
+ # Store all the domains object passed to the problem's conditions
+ for name, cond in self.conditions.items():
+ if isinstance(cond, DomainEquationCondition):
+ if isinstance(cond.domain, DomainInterface):
+ self.domains[name] = cond.domain
+ cond.domain = name
+
+ def __deepcopy__(self, memo):
+ """
+ Create a deep copy of the problem instance.
+
+ :param dict memo: The memorization dictionary used by the deepcopy
+ function.
+ :return: A deep copy of the problem instance.
+ :rtype: ProblemInterface
+ """
+ # Create a new instance of the same class and store it in a dictionary
+ result = self.__class__.__new__(self.__class__)
+ memo[id(self)] = result
+
+ # Set the attributes of the new instance to deep copies of the original
+ for k, v in self.__dict__.items():
+ setattr(result, k, deepcopy(v, memo))
+
+ return result
+
+ def discretise_domain(
+ self, n=None, mode="random", domains=None, sample_rules=None
+ ):
+ """
+ Discretise the problem's domains by sampling a specified number of
+ points according to the selected sampling mode.
+
+ :param int n: The number of points to sample. This is ignored if
+ ``sample_rules`` is provided. Default is ``None``.
+ :param str mode: The sampling method. Available modes include:
+ ``"random"`` for random sampling, ``"latin"`` or ``"lh"`` for latin
+ hypercube sampling, ``"chebyshev"`` for Chebyshev sampling, and
+ ``"grid"`` for grid sampling. Default is ``"random"``.
+ :param domains: The domains from which to sample. If ``None``, all
+ domains are considered for sampling. Default is ``None``.
+ :type domains: str | list[str]
+ :param dict sample_rules: The dictionary specifying custom sampling
+ rules for each input variable. When provided, it overrides the
+ global ``n`` and ``mode`` arguments. Each key in the dictionary must
+ match one of the variables defined in :meth:`input_variables`, and
+ each value must be a dictionary containing two keys: ``n`` for the
+ number of points to sample for that variable, and ``mode`` for the
+ sampling method to use. If ``None``, the global ``n`` and ``mode``
+ parameters are used for all variables. Default is ``None``.
+ :raises ValueError: If ``sample_rules`` is provided but it is not a
+ dictionary.
+ :raises ValueError: If ``sample_rules`` is provided but its keys do not
+ match the input variables of the problem.
+ :raises ValueError: If ``sample_rules`` is provided but any of its rules
+ is not a dictionary containing both ``n`` and ``mode`` keys, with
+ ``n`` being a positive integer and ``mode`` being a string.
+ :raises AssertionError: If ``n`` is not a positive integer.
+ :raises ValueError: If ``mode`` is not a string
+ :raises ValueError: If ``domains`` is provided by it is neither a string
+ nor a list of strings.
+
+ .. warning::
+ ``"random"`` is the only supported ``mode`` across all geometries:
+ :class:`~pina.domain.cartesian_domain.CartesianDomain`,
+ :class:`~pina.domain.ellipsoid_domain.EllipsoidDomain`, and
+ :class:`~pina.domain.simplex_domain.SimplexDomain`.
+ Sampling modes such as ``"latin"``, ``"chebyshev"``, and ``"grid"``
+ are only implemented for
+ :class:`~pina.domain.cartesian_domain.CartesianDomain`.
+ When custom discretisation is specified via ``sample_rules``, the
+ domain to be discretised must be an instance of
+ :class:`~pina.domain.cartesian_domain.CartesianDomain`.
+
+ :Example:
+ >>> problem.discretise_domain(n=10, mode="random")
+ >>> problem.discretise_domain(n=10, mode="lh", domains=["boundary"])
+ >>> problem.discretise_domain(
+ ... sample_rules={
+ ... 'x': {'n': 10, 'mode': 'grid'},
+ ... 'y': {'n': 100, 'mode': 'grid'}
+ ... },
+ ... )
+ """
+ # Initialize the domains to be discretised
+ if domains is None:
+ domains = list(self.domains)
+ if not isinstance(domains, (list)):
+ domains = [domains]
+
+ # Check sampling rules
+ if sample_rules is not None:
+ check_consistency(sample_rules, dict)
+
+ # Check that the keys of sample_rules match the input variables
+ if sorted(list(sample_rules.keys())) != sorted(
+ self.input_variables
+ ):
+ raise ValueError(
+ "The keys of the sample_rules dictionary must match the "
+ "input variables."
+ )
+
+ # Check that the rules for each variable are valid
+ for var, rules in sample_rules.items():
+ check_consistency(rules, dict)
+ if "n" not in rules or "mode" not in rules:
+ raise ValueError(
+ f"Sampling rules for variable {var} must contain 'n' "
+ "and 'mode' keys."
+ )
+ check_positive_integer(rules["n"], strict=True)
+ check_consistency(rules["mode"], str)
+
+ # Check n only if sample_rules is not provided
+ else:
+ check_positive_integer(n, strict=True)
+
+ # Check consistency
+ check_consistency(mode, str)
+ check_consistency(domains, str)
+
+ # If sample_rules is provided, apply custom discretisation
+ if sample_rules is not None:
+ for d in domains:
+
+ # Discretise each variable according to its custom rules
+ discretised_tensor = [
+ self.domains[d].sample(rules["n"], rules["mode"], var)
+ for var, rules in sample_rules.items()
+ ]
+
+ # Merge the discretised tensors into a single one for the domain
+ self.discretised_domains[d] = merge_tensors(discretised_tensor)
+
+ # Otherwise, apply the same n and mode to all specified domains
+ else:
+ for d in domains:
+ self.discretised_domains[d] = self.domains[d].sample(n, mode)
+
+ def add_points(self, new_points_dict):
+ """
+ Append additional points to an already discretised domain.
+
+ :param dict new_points_dict: The dictionary mapping each domain to the
+ corresponding set of new points to be added. Each key in the
+ dictionary must match one of the domains defined in :attr:`domains`,
+ and each value must be a :class:`~pina.tensor.LabelTensor`
+ containing the new points to be added to that domain. The labels of
+ the points to be added must correspond to those of the domain to
+ which they are being added.
+ :raises ValueError: If ``new_points_dict`` is not a dictionary.
+ :raises ValueError: If any of the values in ``new_points_dict`` is not
+ a :class:`~pina.tensor.LabelTensor`.
+ :raises ValueError: If any of the keys in ``new_points_dict`` does not
+ match any of the domains defined in :attr:`domains`.
+ :raises ValueError: If any of the domains in ``new_points_dict`` has not
+ been discretised yet.
+
+ :Example:
+ >>> additional_points = {
+ ... "boundary": LabelTensor(torch.rand(5, 2), labels=["x", "y"])
+ ... }
+ >>> problem.add_points(additional_points)
+ """
+ # Check consistency
+ check_consistency(new_points_dict, dict)
+
+ # Check the keys and values of the dictionary
+ for key, value in new_points_dict.items():
+ check_consistency(value, LabelTensor)
+ if key not in self.domains:
+ raise ValueError(
+ f"Key {key} does not match any domain of the problem."
+ )
+ if key not in self.discretised_domains:
+ raise ValueError(f"Domain {key} has not been discretised yet.")
+
+ # Append the new points to the corresponding discretised domains
+ for key, value in new_points_dict.items():
+ self.discretised_domains[key] = LabelTensor.vstack(
+ [self.discretised_domains[key], value]
+ )
+
+ def move_discretisation_into_conditions(self):
+ """
+ Move the sampled points from the discretised domains into their
+ corresponding conditions. This ensures that the conditions are evaluated
+ on the correct set of points after discretisation.
+ """
+ # Move the discretised domains into their corresponding conditions
+ for name, cond in self.conditions.items():
+ if hasattr(cond, "domain"):
+
+ # Create a new condition with the discretised domain as input
+ new_condition = Condition(
+ input=self.discretised_domains[cond.domain],
+ equation=cond.equation,
+ )
+
+ # Set the domain and problem attributes of the new condition
+ new_condition.domain = cond.domain
+ new_condition.problem = self
+ new_condition.name = name
+
+ # Replace the old condition in the conditions dictionary
+ self.conditions[name] = new_condition
+
+ @property
+ def input_variables(self):
+ """
+ The input variables of the problem.
+
+ :return: The input variables of the problem.
+ :rtype: list[str]
+ """
+ # Define a helper function to convert a string to a list if needed
+ _as_list = lambda x: [x] if isinstance(x, str) else x
+
+ # Collect the spatial, temporal, and parametric variables
+ variables = []
+ if hasattr(self, "spatial_variables"):
+ variables += _as_list(self.spatial_variables)
+ if hasattr(self, "temporal_variables"):
+ variables += _as_list(self.temporal_variables)
+ if hasattr(self, "parameters"):
+ variables += _as_list(self.parameters)
+
+ return variables
+
+ @property
+ def discretised_domains(self):
+ """
+ The dictionary containing the discretised domains of the problem. Each
+ key corresponds to a domain defined in :attr:`domains`, and each value
+ is a :class:`~pina.tensor.LabelTensor` containing the sampled points for
+ that domain.
+
+ :return: The discretised domains.
+ :rtype: dict
+ """
+ return self._discretised_domains
+
+ @property
+ def are_all_domains_discretised(self):
+ """
+ Whether all domains of the problem have been discretised.
+
+ :return: ``True`` if all domains are discretised, ``False`` otherwise.
+ :rtype: bool
+ """
+ return all(d in self.discretised_domains for d in self.domains)
diff --git a/pina/problem/inverse_problem.py b/pina/_src/problem/inverse_problem.py
similarity index 63%
rename from pina/problem/inverse_problem.py
rename to pina/_src/problem/inverse_problem.py
index 8a2902448..7ee28bb96 100644
--- a/pina/problem/inverse_problem.py
+++ b/pina/_src/problem/inverse_problem.py
@@ -2,13 +2,18 @@
from abc import abstractmethod
import torch
-from .abstract_problem import AbstractProblem
+from pina._src.problem.base_problem import BaseProblem
-class InverseProblem(AbstractProblem):
+class InverseProblem(BaseProblem):
"""
- Class for defining inverse problems, where the objective is to determine
- unknown parameters through training, based on given data.
+ Base class for all inverse problems, extending the standard problem
+ definition with unknown parameters to be determined through training.
+
+ An inverse problem is defined by a set of unknown parameters that need to be
+ estimated from observed data.
+
+ This class is not meant to be instantiated directly.
"""
def __init__(self):
@@ -16,15 +21,15 @@ def __init__(self):
Initialization of the :class:`InverseProblem` class.
"""
super().__init__()
- # storing unknown_parameters for optimization
+
+ # Set the unknown parameters as trainable parameters
self.unknown_parameters = {}
for var in self.unknown_variables:
- range_var = self.unknown_parameter_domain._range[var]
- tensor_var = (
- torch.rand(1, requires_grad=True) * range_var[1] + range_var[0]
- )
+ low, high = self.unknown_parameter_domain._range[var]
+ tensor_var = low + (high - low) * torch.rand(1)
self.unknown_parameters[var] = torch.nn.Parameter(tensor_var)
+ @property
@abstractmethod
def unknown_parameter_domain(self):
"""
@@ -34,7 +39,7 @@ def unknown_parameter_domain(self):
@property
def unknown_variables(self):
"""
- Get the unknown variables of the problem.
+ The unknown variables of the problem.
:return: The unknown variables of the problem.
:rtype: list[str]
@@ -44,7 +49,7 @@ def unknown_variables(self):
@property
def unknown_parameters(self):
"""
- Get the unknown parameters of the problem.
+ The unknown parameters of the problem.
:return: The unknown parameters of the problem.
:rtype: torch.nn.Parameter
diff --git a/pina/_src/problem/parametric_problem.py b/pina/_src/problem/parametric_problem.py
new file mode 100644
index 000000000..12a9cd089
--- /dev/null
+++ b/pina/_src/problem/parametric_problem.py
@@ -0,0 +1,35 @@
+"""Module for the ParametricProblem class."""
+
+from abc import abstractmethod
+from pina._src.problem.base_problem import BaseProblem
+
+
+class ParametricProblem(BaseProblem):
+ """
+ Base class for all parametric problems, extending the standard problem
+ definition with parameter-dependent inputs.
+
+ A parametric problem includes additional input variables, defined over a
+ dedicated parameter domain, which represent external quantities
+ (e.g., physical coefficients or control variables) that can vary across
+ different evaluations and influence the solution.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ @property
+ @abstractmethod
+ def parameter_domain(self):
+ """
+ The domain of the parameters of the problem.
+ """
+
+ @property
+ def parameters(self):
+ """
+ The parameters of the problem.
+
+ :return: The parameters of the problem.
+ :rtype: list[str]
+ """
+ return self.parameter_domain.variables
diff --git a/pina/_src/problem/problem_interface.py b/pina/_src/problem/problem_interface.py
new file mode 100644
index 000000000..d64130d61
--- /dev/null
+++ b/pina/_src/problem/problem_interface.py
@@ -0,0 +1,150 @@
+"""Module for the Problem Interface."""
+
+from abc import ABCMeta, abstractmethod
+
+
+class ProblemInterface(metaclass=ABCMeta):
+ """
+ Abstract interface for all problems.
+ """
+
+ @abstractmethod
+ def __deepcopy__(self, memo):
+ """
+ Create a deep copy of the problem instance.
+
+ :param dict memo: The memorization dictionary used by the deepcopy
+ function.
+ :return: A deep copy of the problem instance.
+ :rtype: ProblemInterface
+ """
+
+ @abstractmethod
+ def discretise_domain(
+ self, n=None, mode="random", domains=None, sample_rules=None
+ ):
+ """
+ Discretise the problem's domains by sampling a specified number of
+ points according to the selected sampling mode.
+
+ :param int n: The number of points to sample. This is ignored if
+ ``sample_rules`` is provided. Default is ``None``.
+ :param str mode: The sampling method. Available modes include:
+ ``"random"`` for random sampling, ``"latin"`` or ``"lh"`` for latin
+ hypercube sampling, ``"chebyshev"`` for Chebyshev sampling, and
+ ``"grid"`` for grid sampling. Default is ``"random"``.
+ :param domains: The domains from which to sample. If ``None``, all
+ domains are considered for sampling. Default is ``None``.
+ :type domains: str | list[str]
+ :param dict sample_rules: The dictionary specifying custom sampling
+ rules for each input variable. When provided, it overrides the
+ global ``n`` and ``mode`` arguments. Each key in the dictionary must
+ match one of the variables defined in :meth:`input_variables`, and
+ each value must be a dictionary containing two keys: ``n`` for the
+ number of points to sample for that variable, and ``mode`` for the
+ sampling method to use. If ``None``, the global ``n`` and ``mode``
+ parameters are used for all variables. Default is ``None``.
+
+ .. warning::
+ ``"random"`` is the only supported ``mode`` across all geometries:
+ :class:`~pina.domain.cartesian_domain.CartesianDomain`,
+ :class:`~pina.domain.ellipsoid_domain.EllipsoidDomain`, and
+ :class:`~pina.domain.simplex_domain.SimplexDomain`.
+ Sampling modes such as ``"latin"``, ``"chebyshev"``, and ``"grid"``
+ are only implemented for
+ :class:`~pina.domain.cartesian_domain.CartesianDomain`.
+ When custom discretisation is specified via ``sample_rules``, the
+ domain to be discretised must be an instance of
+ :class:`~pina.domain.cartesian_domain.CartesianDomain`.
+
+ :Example:
+ >>> problem.discretise_domain(n=10, mode="random")
+ >>> problem.discretise_domain(n=10, mode="lh", domains=["boundary"])
+ >>> problem.discretise_domain(
+ ... sample_rules={
+ ... 'x': {'n': 10, 'mode': 'grid'},
+ ... 'y': {'n': 100, 'mode': 'grid'}
+ ... },
+ ... )
+ """
+
+ @abstractmethod
+ def add_points(self, new_points_dict):
+ """
+ Append additional points to an already discretised domain.
+
+ :param dict new_points_dict: The dictionary mapping each domain to the
+ corresponding set of new points to be added. Each key in the
+ dictionary must match one of the domains defined in :attr:`domains`,
+ and each value must be a :class:`~pina.tensor.LabelTensor`
+ containing the new points to be added to that domain. The labels of
+ the points to be added must correspond to those of the domain to
+ which they are being added.
+
+ :Example:
+ >>> additional_points = {
+ ... "boundary": LabelTensor(torch.rand(5, 2), labels=["x", "y"])
+ ... }
+ >>> problem.add_points(additional_points)
+ """
+
+ @abstractmethod
+ def move_discretisation_into_conditions(self):
+ """
+ Move the sampled points from the discretised domains into their
+ corresponding conditions. This ensures that the conditions are evaluated
+ on the correct set of points after discretisation.
+ """
+
+ @property
+ @abstractmethod
+ def input_variables(self):
+ """
+ The input variables of the problem.
+
+ :return: The input variables of the problem.
+ :rtype: list[str]
+ """
+
+ @property
+ @abstractmethod
+ def output_variables(self):
+ """
+ The output variables of the problem.
+
+ :return: The output variables of the problem.
+ :rtype: list[str]
+ """
+
+ @property
+ @abstractmethod
+ def conditions(self):
+ """
+ The conditions associated with the problem.
+
+ :return: The conditions associated with the problem.
+ :rtype: dict
+ """
+
+ @property
+ @abstractmethod
+ def discretised_domains(self):
+ """
+ The dictionary containing the discretised domains of the problem. Each
+ key corresponds to a domain defined in :attr:`domains`, and each value
+ is a :class:`~pina.tensor.LabelTensor` containing the sampled points for
+ that domain.
+
+ :return: The discretised domains.
+ :rtype: dict
+ """
+
+ @property
+ @abstractmethod
+ def are_all_domains_discretised(self):
+ """
+ Whether all domains of the problem have been discretised.
+
+ :return: ``True`` if all domains are discretised, ``False`` otherwise.
+ :rtype: bool
+ """
diff --git a/pina/_src/problem/spatial_problem.py b/pina/_src/problem/spatial_problem.py
new file mode 100644
index 000000000..16ea9365b
--- /dev/null
+++ b/pina/_src/problem/spatial_problem.py
@@ -0,0 +1,34 @@
+"""Module for the SpatialProblem class."""
+
+from abc import abstractmethod
+from pina._src.problem.base_problem import BaseProblem
+
+
+class SpatialProblem(BaseProblem):
+ """
+ Base class for all spatial problems, extending the standard problem
+ definition with spatial-dependent inputs.
+
+ A spatial problem is defined over a spatial domain, where input variables
+ represent the coordinates of the system (e.g., positions in one or more
+ dimensions) on which the solution is evaluated.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ @property
+ @abstractmethod
+ def spatial_domain(self):
+ """
+ The domain of spatial variables of the problem.
+ """
+
+ @property
+ def spatial_variables(self):
+ """
+ The spatial input variables of the problem.
+
+ :return: The spatial input variables of the problem.
+ :rtype: list[str]
+ """
+ return self.spatial_domain.variables
diff --git a/pina/_src/problem/time_dependent_problem.py b/pina/_src/problem/time_dependent_problem.py
new file mode 100644
index 000000000..b81ab4778
--- /dev/null
+++ b/pina/_src/problem/time_dependent_problem.py
@@ -0,0 +1,33 @@
+"""Module for the TimeDependentProblem class."""
+
+from abc import abstractmethod
+from pina._src.problem.base_problem import BaseProblem
+
+
+class TimeDependentProblem(BaseProblem):
+ """
+ Base class for all time-dependent problems, extending the standard problem
+ definition with time-dependent inputs.
+
+ A time-dependent problem is defined over a temporal domain, where input
+ variables represent the time at which the solution is evaluated.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ @property
+ @abstractmethod
+ def temporal_domain(self):
+ """
+ The domain of temporal variables of the problem.
+ """
+
+ @property
+ def temporal_variables(self):
+ """
+ The temporal variables of the problem.
+
+ :return: The temporal variables of the problem.
+ :rtype: list[str]
+ """
+ return self.temporal_domain.variables
diff --git a/pina/_src/problem/zoo/__init__.py b/pina/_src/problem/zoo/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/problem/zoo/acoustic_wave.py b/pina/_src/problem/zoo/acoustic_wave_problem.py
similarity index 56%
rename from pina/problem/zoo/acoustic_wave.py
rename to pina/_src/problem/zoo/acoustic_wave_problem.py
index b4b2035a4..f2f7cff8b 100644
--- a/pina/problem/zoo/acoustic_wave.py
+++ b/pina/_src/problem/zoo/acoustic_wave_problem.py
@@ -1,17 +1,16 @@
"""Formulation of the acoustic wave problem."""
import torch
-from ... import Condition
-from ...problem import SpatialProblem, TimeDependentProblem
-from ...utils import check_consistency
-from ...domain import CartesianDomain
-from ...equation import (
- Equation,
- SystemEquation,
- FixedValue,
- FixedGradient,
- AcousticWave,
-)
+from pina._src.problem.time_dependent_problem import TimeDependentProblem
+from pina._src.domain.cartesian_domain import CartesianDomain
+from pina._src.equation.system_equation import SystemEquation
+from pina._src.problem.spatial_problem import SpatialProblem
+from pina._src.condition.condition import Condition
+from pina._src.core.utils import check_consistency
+from pina._src.equation.equation import Equation
+from pina._src.equation.zoo.fixed_value import FixedValue
+from pina._src.equation.zoo.fixed_gradient import FixedGradient
+from pina._src.equation.zoo.acoustic_wave_equation import AcousticWaveEquation
def initial_condition(input_, output_):
@@ -29,8 +28,50 @@ def initial_condition(input_, output_):
class AcousticWaveProblem(TimeDependentProblem, SpatialProblem):
r"""
- Implementation of the acoustic wave problem in the spatial interval
- :math:`[0, 1]` and temporal interval :math:`[0, 1]`.
+ Implementation of the one-dimensional acoustic wave problem on the
+ space-time domain :math:`\Omega\times T = [0, 1] \times [0, 1]`.
+
+ The problem is governed by the acoustic wave equation
+
+ .. math::
+
+ \frac{\partial^2 u}{\partial t^2}
+ =
+ c^2 \frac{\partial^2 u}{\partial x^2},
+
+ where :math:`u = u(x, t)` is the solution field and :math:`c > 0` is the
+ wave propagation speed.
+
+ Homogeneous Dirichlet boundary conditions are imposed at the spatial
+ boundaries:
+
+ .. math::
+
+ u(0, t) = u(1, t) = 0, \qquad t \in [0, 1].
+
+ The initial displacement is prescribed as
+
+ .. math::
+
+ u(x, 0) = \sin(\pi x) + \frac{1}{2}\sin(4\pi x),
+ \qquad x \in [0, 1],
+
+ together with zero initial velocity:
+
+ .. math::
+
+ \frac{\partial u}{\partial t}(x, 0) = 0,
+ \qquad x \in [0, 1].
+
+ The analytical solution is given by
+
+ .. math::
+
+ u(x, t)
+ =
+ \sin(\pi x)\cos(c\pi t)
+ +
+ \frac{1}{2}\sin(4\pi x)\cos(4c\pi t).
.. seealso::
@@ -69,7 +110,7 @@ def __init__(self, c=2.0):
"""
Initialization of the :class:`AcousticWaveProblem` class.
- :param c: The wave propagation speed. Default is 2.0.
+ :param c: The wave propagation speed. Default is ``2.0``.
:type c: float | int
"""
super().__init__()
@@ -77,7 +118,7 @@ def __init__(self, c=2.0):
self.c = c
self.conditions["D"] = Condition(
- domain="D", equation=AcousticWave(self.c)
+ domain="D", equation=AcousticWaveEquation(self.c)
)
def solution(self, pts):
@@ -92,4 +133,7 @@ def solution(self, pts):
arg_t = self.c * torch.pi * pts["t"]
term1 = torch.sin(arg_x) * torch.cos(arg_t)
term2 = 0.5 * torch.sin(4 * arg_x) * torch.cos(4 * arg_t)
- return term1 + term2
+
+ sol = term1 + term2
+ sol.labels = self.output_variables
+ return sol
diff --git a/pina/problem/zoo/advection.py b/pina/_src/problem/zoo/advection_problem.py
similarity index 57%
rename from pina/problem/zoo/advection.py
rename to pina/_src/problem/zoo/advection_problem.py
index c709b9632..113b36bee 100644
--- a/pina/problem/zoo/advection.py
+++ b/pina/_src/problem/zoo/advection_problem.py
@@ -1,11 +1,13 @@
"""Formulation of the advection problem."""
import torch
-from ... import Condition
-from ...problem import SpatialProblem, TimeDependentProblem
-from ...equation import Equation, Advection
-from ...utils import check_consistency
-from ...domain import CartesianDomain
+from pina._src.problem.time_dependent_problem import TimeDependentProblem
+from pina._src.domain.cartesian_domain import CartesianDomain
+from pina._src.problem.spatial_problem import SpatialProblem
+from pina._src.equation.zoo.advection_equation import AdvectionEquation
+from pina._src.condition.condition import Condition
+from pina._src.core.utils import check_consistency
+from pina._src.equation.equation import Equation
def initial_condition(input_, output_):
@@ -22,8 +24,39 @@ def initial_condition(input_, output_):
class AdvectionProblem(SpatialProblem, TimeDependentProblem):
r"""
- Implementation of the advection problem in the spatial interval
- :math:`[0, 2 \pi]` and temporal interval :math:`[0, 1]`.
+ Implementation of the one-dimensional advection problem on the space-time
+ domain :math:`\Omega\times T = [0, 2\pi] \times [0, 1]`.
+
+ The problem is governed by the linear advection equation
+
+ .. math::
+
+ \frac{\partial u}{\partial t}
+ +
+ c \frac{\partial u}{\partial x}
+ =
+ 0,
+
+ where :math:`u = u(x, t)` is the solution field and :math:`c` is the
+ advection velocity.
+
+ Periodic boundary conditions are imposed at the spatial boundaries:
+
+ .. math::
+
+ u(0, t) = u(2\pi, t), \qquad t \in [0, 1].
+
+ The initial condition is prescribed as
+
+ .. math::
+
+ u(x, 0) = \sin(x), \qquad x \in [0, 2\pi].
+
+ The analytical solution is given by
+
+ .. math::
+
+ u(x, t) = \sin(x - ct).
.. seealso::
@@ -54,14 +87,16 @@ def __init__(self, c=1.0):
"""
Initialization of the :class:`AdvectionProblem`.
- :param c: The advection velocity parameter. Default is 1.0.
+ :param c: The advection velocity parameter. Default is ``1.0``.
:type c: float | int
"""
super().__init__()
check_consistency(c, (float, int))
self.c = c
- self.conditions["D"] = Condition(domain="D", equation=Advection(self.c))
+ self.conditions["D"] = Condition(
+ domain="D", equation=AdvectionEquation(self.c)
+ )
def solution(self, pts):
"""
diff --git a/pina/problem/zoo/allen_cahn.py b/pina/_src/problem/zoo/allen_cahn_problem.py
similarity index 57%
rename from pina/problem/zoo/allen_cahn.py
rename to pina/_src/problem/zoo/allen_cahn_problem.py
index 900d5cf33..6a7126e68 100644
--- a/pina/problem/zoo/allen_cahn.py
+++ b/pina/_src/problem/zoo/allen_cahn_problem.py
@@ -1,11 +1,13 @@
"""Formulation of the Allen Cahn problem."""
import torch
-from ... import Condition
-from ...problem import SpatialProblem, TimeDependentProblem
-from ...equation import Equation, AllenCahn
-from ...utils import check_consistency
-from ...domain import CartesianDomain
+from pina._src.condition.condition import Condition
+from pina._src.problem.spatial_problem import SpatialProblem
+from pina._src.problem.time_dependent_problem import TimeDependentProblem
+from pina._src.equation.equation import Equation
+from pina._src.equation.zoo.allen_cahn_equation import AllenCahnEquation
+from pina._src.core.utils import check_consistency
+from pina._src.domain.cartesian_domain import CartesianDomain
def initial_condition(input_, output_):
@@ -24,8 +26,35 @@ def initial_condition(input_, output_):
class AllenCahnProblem(TimeDependentProblem, SpatialProblem):
r"""
- Implementation of the Allen Cahn problem in the spatial interval
- :math:`[-1, 1]` and temporal interval :math:`[0, 1]`.
+ Implementation of the one-dimensional Allen-Cahn problem on the space-time
+ domain :math:`\Omega\times T = [-1, 1] \times [0, 1]`.
+
+ The problem is governed by the Allen-Cahn equation
+
+ .. math::
+
+ \frac{\partial u}{\partial t}
+ -
+ \alpha \frac{\partial^2 u}{\partial x^2}
+ +
+ \beta \left(u^3 - u\right)
+ =
+ 0,
+
+ where :math:`u = u(x, t)` is the solution field, :math:`\alpha` is the
+ diffusion coefficient, and :math:`\beta` is the reaction coefficient.
+
+ Periodic boundary conditions are imposed at the spatial boundaries:
+
+ .. math::
+
+ u(-1, t) = u(1, t), \qquad t \in [0, 1].
+
+ The initial condition is prescribed as
+
+ .. math::
+
+ u(x, 0) = x^2 \cos(\pi x), \qquad x \in [-1, 1].
.. seealso::
@@ -59,9 +88,9 @@ def __init__(self, alpha=1e-4, beta=5):
"""
Initialization of the :class:`AllenCahnProblem`.
- :param alpha: The diffusion coefficient. Default is 1e-4.
+ :param alpha: The diffusion coefficient. Default is ``1e-4``.
:type alpha: float | int
- :param beta: The reaction coefficient. Default is 5.0.
+ :param beta: The reaction coefficient. Default is ``5.0``.
:type beta: float | int
"""
super().__init__()
@@ -72,5 +101,5 @@ def __init__(self, alpha=1e-4, beta=5):
self.conditions["D"] = Condition(
domain="D",
- equation=AllenCahn(alpha=self.alpha, beta=self.beta),
+ equation=AllenCahnEquation(alpha=self.alpha, beta=self.beta),
)
diff --git a/pina/_src/problem/zoo/burgers_problem.py b/pina/_src/problem/zoo/burgers_problem.py
new file mode 100644
index 000000000..0ba779a22
--- /dev/null
+++ b/pina/_src/problem/zoo/burgers_problem.py
@@ -0,0 +1,102 @@
+"""Formulation of the Burgers' problem."""
+
+import torch
+from pina._src.problem.time_dependent_problem import TimeDependentProblem
+from pina._src.domain.cartesian_domain import CartesianDomain
+from pina._src.problem.spatial_problem import SpatialProblem
+from pina._src.condition.condition import Condition
+from pina._src.core.utils import check_consistency
+from pina._src.equation.equation import Equation
+from pina._src.equation.zoo.fixed_value import FixedValue
+from pina._src.equation.zoo.burgers_equation import BurgersEquation
+
+
+def initial_condition(input_, output_):
+ """
+ Definition of the initial condition of the Burgers' problem.
+
+ :param LabelTensor input_: The input data of the problem.
+ :param LabelTensor output_: The output data of the problem.
+ :return: The residual of the initial condition.
+ :rtype: LabelTensor
+ """
+ return output_ + torch.sin(torch.pi * input_["x"])
+
+
+class BurgersProblem(TimeDependentProblem, SpatialProblem):
+ r"""
+ Implementation of the one-dimensional Burgers' problem on the space-time
+ domain :math:`\Omega\times T = [-1, 1] \times [0, 1]`.
+
+ The problem is governed by the Burgers' equation
+
+ .. math::
+
+ \frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} =
+ \nu \frac{\partial^2 u}{\partial x^2},
+
+ where :math:`u = u(x, t)` is the solution field and :math:`\nu \geq 0`
+ is the viscosity coefficient. For :math:`\nu = 0`, the equation reduces
+ to the inviscid Burgers' equation.
+
+ Homogeneous Dirichlet boundary conditions are imposed at the spatial
+ boundaries:
+
+ .. math::
+ u(-1, t) = u(1, t) = 0, \qquad t \in [0, 1].
+
+ The initial condition is prescribed as
+
+ .. math::
+ u(x, 0) = -\sin(\pi x), \qquad x \in [-1, 1].
+
+
+ .. seealso::
+
+ **Original reference**: Raissi M., Perdikaris P., Karniadakis G. E.
+ (2017).
+ *Physics Informed Deep Learning (Part I): Data-driven Solutions of
+ Nonlinear Partial Differential Equations*.
+ DOI: `10.48550 `_.
+
+ :Example:
+
+ >>> problem = BurgersProblem()
+ """
+
+ output_variables = ["u"]
+ spatial_domain = CartesianDomain({"x": [-1, 1]})
+ temporal_domain = CartesianDomain({"t": [0, 1]})
+
+ domains = {
+ "D": spatial_domain.update(temporal_domain),
+ "t0": spatial_domain.update(CartesianDomain({"t": 0})),
+ "boundary": spatial_domain.partial().update(temporal_domain),
+ }
+
+ conditions = {
+ "boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
+ "t0": Condition(domain="t0", equation=Equation(initial_condition)),
+ }
+
+ def __init__(self, nu=0):
+ """
+ Initialization of the :class:`BurgersProblem` class.
+
+ :param nu: The viscosity coefficient.
+ :type nu: float | int
+ :raises ValueError: If ``nu`` is not a float or an int.
+ :raises ValueError: If ``nu`` is negative.
+ """
+ super().__init__()
+
+ # Check consistency
+ check_consistency(nu, (float, int))
+ if nu < 0:
+ raise ValueError(
+ "The viscosity ``nu`` must be a non-negative float or int."
+ )
+
+ self.conditions["D"] = Condition(
+ domain="D", equation=BurgersEquation(nu)
+ )
diff --git a/pina/problem/zoo/diffusion_reaction.py b/pina/_src/problem/zoo/diffusion_reaction_problem.py
similarity index 57%
rename from pina/problem/zoo/diffusion_reaction.py
rename to pina/_src/problem/zoo/diffusion_reaction_problem.py
index fd02b8368..7a5584ca5 100644
--- a/pina/problem/zoo/diffusion_reaction.py
+++ b/pina/_src/problem/zoo/diffusion_reaction_problem.py
@@ -1,11 +1,16 @@
"""Formulation of the diffusion-reaction problem."""
import torch
-from ... import Condition
-from ...equation import Equation, FixedValue, DiffusionReaction
-from ...problem import SpatialProblem, TimeDependentProblem
-from ...utils import check_consistency
-from ...domain import CartesianDomain
+from pina._src.condition.condition import Condition
+from pina._src.equation.equation import Equation
+from pina._src.equation.zoo.fixed_value import FixedValue
+from pina._src.problem.spatial_problem import SpatialProblem
+from pina._src.problem.time_dependent_problem import TimeDependentProblem
+from pina._src.core.utils import check_consistency
+from pina._src.domain.cartesian_domain import CartesianDomain
+from pina._src.equation.zoo.diffusion_reaction_equation import (
+ DiffusionReactionEquation,
+)
def initial_condition(input_, output_):
@@ -30,8 +35,63 @@ def initial_condition(input_, output_):
class DiffusionReactionProblem(TimeDependentProblem, SpatialProblem):
r"""
- Implementation of the diffusion-reaction problem in the spatial interval
- :math:`[-\pi, \pi]` and temporal interval :math:`[0, 1]`.
+ Implementation of the one-dimensional diffusion-reaction problem on the
+ space-time domain :math:`\Omega\times T = [-\pi, \pi] \times [0, 1]`.
+
+ The problem is governed by the forced diffusion-reaction equation
+
+ .. math::
+
+ \frac{\partial u}{\partial t}
+ -
+ \alpha \frac{\partial^2 u}{\partial x^2}
+ =
+ f(x, t),
+
+ where :math:`u = u(x, t)` is the solution field, :math:`\alpha` is the
+ diffusion coefficient, and :math:`f(x, t)` is a forcing term.
+
+ Homogeneous Dirichlet boundary conditions are imposed at the spatial
+ boundaries:
+
+ .. math::
+
+ u(-\pi, t) = u(\pi, t) = 0, \qquad t \in [0, 1].
+
+ The initial condition is prescribed as
+
+ .. math::
+
+ u(x, 0)
+ =
+ \sin(x)
+ +
+ \frac{1}{2}\sin(2x)
+ +
+ \frac{1}{3}\sin(3x)
+ +
+ \frac{1}{4}\sin(4x)
+ +
+ \frac{1}{8}\sin(8x).
+
+ The analytical solution is given by
+
+ .. math::
+
+ u(x, t)
+ =
+ e^{-t}
+ \left(
+ \sin(x)
+ +
+ \frac{1}{2}\sin(2x)
+ +
+ \frac{1}{3}\sin(3x)
+ +
+ \frac{1}{4}\sin(4x)
+ +
+ \frac{1}{8}\sin(8x)
+ \right).
.. seealso::
@@ -63,7 +123,7 @@ def __init__(self, alpha=1e-4):
"""
Initialization of the :class:`DiffusionReactionProblem`.
- :param alpha: The diffusion coefficient. Default is 1e-4.
+ :param alpha: The diffusion coefficient. Default is ``1e-4``.
:type alpha: float | int
"""
super().__init__()
@@ -80,15 +140,16 @@ def forcing_term(input_):
t = input_.extract("t")
return torch.exp(-t) * (
- 1.5 * torch.sin(2 * x)
- + (8 / 3) * torch.sin(3 * x)
- + (15 / 4) * torch.sin(4 * x)
- + (63 / 8) * torch.sin(8 * x)
+ (self.alpha - 1) * torch.sin(x)
+ + ((4 * self.alpha - 1) / 2) * torch.sin(2 * x)
+ + ((9 * self.alpha - 1) / 3) * torch.sin(3 * x)
+ + ((16 * self.alpha - 1) / 4) * torch.sin(4 * x)
+ + ((64 * self.alpha - 1) / 8) * torch.sin(8 * x)
)
self.conditions["D"] = Condition(
domain="D",
- equation=DiffusionReaction(self.alpha, forcing_term),
+ equation=DiffusionReactionEquation(self.alpha, forcing_term),
)
def solution(self, pts):
diff --git a/pina/_src/problem/zoo/helmholtz_problem.py b/pina/_src/problem/zoo/helmholtz_problem.py
new file mode 100644
index 000000000..9b11519aa
--- /dev/null
+++ b/pina/_src/problem/zoo/helmholtz_problem.py
@@ -0,0 +1,124 @@
+"""Formulation of the Helmholtz problem."""
+
+import torch
+from pina._src.condition.condition import Condition
+from pina._src.equation.zoo.fixed_value import FixedValue
+from pina._src.equation.zoo.helmholtz_equation import HelmholtzEquation
+from pina._src.problem.spatial_problem import SpatialProblem
+from pina._src.core.utils import check_consistency
+from pina._src.domain.cartesian_domain import CartesianDomain
+
+
+class HelmholtzProblem(SpatialProblem):
+ r"""
+ Implementation of the two-dimensional Helmholtz problem on the square domain
+ :math:`\Omega = [-1, 1] \times [-1, 1]`.
+
+ The problem is governed by the forced Helmholtz equation
+
+ .. math::
+
+ \Delta u + k u = f(x, y),
+
+ where :math:`u = u(x, y)` is the solution field, :math:`k` is the squared
+ wavenumber, and :math:`f(x, y)` is a forcing term.
+
+ Homogeneous Dirichlet boundary conditions are imposed on the boundary of
+ the domain:
+
+ .. math::
+
+ u(x, y) = 0, \qquad (x, y) \in \partial \Omega.
+
+ The analytical solution is given by
+
+ .. math::
+
+ u(x, y)
+ =
+ \sin(\alpha_x \pi x)
+ \sin(\alpha_y \pi y),
+
+ with forcing term
+
+ .. math::
+
+ f(x, y)
+ =
+ \left[
+ k - (\alpha_x^2 + \alpha_y^2)\pi^2
+ \right]
+ \sin(\alpha_x \pi x)
+ \sin(\alpha_y \pi y).
+
+ .. seealso::
+
+ **Original reference**: Si, Chenhao, et al. *Complex Physics-Informed
+ Neural Network.* arXiv preprint arXiv:2502.04917 (2025).
+ DOI: `arXiv:2502.04917 `_.
+
+ :Example:
+
+ >>> problem = HelmholtzProblem()
+ """
+
+ output_variables = ["u"]
+ spatial_domain = CartesianDomain({"x": [-1, 1], "y": [-1, 1]})
+
+ domains = {
+ "D": spatial_domain,
+ "boundary": spatial_domain.partial(),
+ }
+
+ conditions = {
+ "boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
+ }
+
+ def __init__(self, k=1.0, alpha_x=1, alpha_y=4):
+ """
+ Initialization of the :class:`HelmholtzProblem` class.
+
+ :param k: The squared wavenumber. Default is ``1.0``.
+ :type k: float | int
+ :param int alpha_x: The frequency in the x-direction. Default is ``1``.
+ :param int alpha_y: The frequency in the y-direction. Default is ``4``.
+ """
+ super().__init__()
+ check_consistency(k, (int, float))
+ check_consistency(alpha_x, int)
+ check_consistency(alpha_y, int)
+ self.k = k
+ self.alpha_x = alpha_x
+ self.alpha_y = alpha_y
+
+ def forcing_term(input_):
+ """
+ Implementation of the forcing term.
+ """
+ x, y, pi = input_["x"], input_["y"], torch.pi
+ factor = (self.alpha_x**2 + self.alpha_y**2) * pi**2
+ return (
+ (self.k - factor)
+ * torch.sin(self.alpha_x * pi * x)
+ * torch.sin(self.alpha_y * pi * y)
+ )
+
+ self.conditions["D"] = Condition(
+ domain="D",
+ equation=HelmholtzEquation(self.k, forcing_term),
+ )
+
+ def solution(self, pts):
+ """
+ Implementation of the analytical solution of the Helmholtz problem.
+
+ :param LabelTensor pts: Points where the solution is evaluated.
+ :return: The analytical solution of the Helmholtz problem.
+ :rtype: LabelTensor
+ """
+ x, y, pi = pts["x"], pts["y"], torch.pi
+ sol = torch.sin(self.alpha_x * pi * x) * torch.sin(
+ self.alpha_y * pi * y
+ )
+ sol.labels = self.output_variables
+ return sol
diff --git a/pina/problem/zoo/inverse_poisson_2d_square.py b/pina/_src/problem/zoo/inverse_poisson_problem.py
similarity index 75%
rename from pina/problem/zoo/inverse_poisson_2d_square.py
rename to pina/_src/problem/zoo/inverse_poisson_problem.py
index 17f30ae14..9a5a8d908 100644
--- a/pina/problem/zoo/inverse_poisson_2d_square.py
+++ b/pina/_src/problem/zoo/inverse_poisson_problem.py
@@ -4,13 +4,16 @@
import requests
import torch
from io import BytesIO
-from ... import Condition
-from ... import LabelTensor
-from ...operator import laplacian
-from ...domain import CartesianDomain
-from ...equation import Equation, FixedValue
-from ...problem import SpatialProblem, InverseProblem
-from ...utils import custom_warning_format, check_consistency
+
+from pina._src.core.utils import custom_warning_format, check_consistency
+from pina._src.domain.cartesian_domain import CartesianDomain
+from pina._src.problem.inverse_problem import InverseProblem
+from pina._src.problem.spatial_problem import SpatialProblem
+from pina._src.equation.zoo.fixed_value import FixedValue
+from pina._src.condition.condition import Condition
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.equation.equation import Equation
+from pina._src.core.operator import laplacian
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=ResourceWarning)
@@ -28,7 +31,7 @@ def _load_tensor_from_url(url, labels, timeout=10):
:param str url: URL to the remote `.pth` tensor file.
:param labels: Labels for the resulting LabelTensor.
:type labels: list[str] | tuple[str]
- :param int timeout: Timeout for the request in seconds. Default is 10s.
+ :param int timeout: Timeout for the request in seconds. Default is ``10`` s.
:return: A LabelTensor object if successful, otherwise None.
:rtype: LabelTensor | None
"""
@@ -71,9 +74,33 @@ def laplace_equation(input_, output_, params_):
class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
r"""
- Implementation of the inverse 2-dimensional Poisson problem in the square
- domain :math:`[0, 1] \times [0, 1]`, with unknown parameter domain
- :math:`[-1, 1] \times [-1, 1]`.
+ Implementation of the inverse two-dimensional Poisson problem on the square
+ domain :math:`\Omega = [-2, 2] \times [-2, 2]`, with unknown parameter
+ domain :math:`\Theta = [-1, 1] \times [-1, 1]`.
+
+ The problem is governed by the parameterized Poisson equation
+
+ .. math::
+
+ \Delta u
+ =
+ \exp\left(
+ -2(x - \mu_1)^2
+ -2(y - \mu_2)^2
+ \right),
+
+ where :math:`u = u(x, y)` is the solution field and :math:`\mu_1, \mu_2` are
+ unknown parameters controlling the forcing term.
+
+ Homogeneous Dirichlet boundary conditions are imposed on the boundary of the
+ domain:
+
+ .. math::
+
+ u(x, y) = 0, \qquad (x, y) \in \partial \Omega.
+
+ The inverse problem aims to infer the unknown parameters :math:`\mu_1` and
+ :math:`\mu_2` from solution data.
The `"data"` condition is added only if the required files are downloaded
successfully.
@@ -105,10 +132,10 @@ def __init__(self, load=True, data_size=1.0):
:param bool load: If True, it attempts to load data from remote URLs.
Set to False to skip data loading (e.g., if no internet connection).
- Default is True.
+ Default is ``True``.
:param float data_size: The fraction of the total data to use for the
"data" condition. If set to 1.0, all available data is used.
- If set to 0.0, no data is used. Default is 1.0.
+ If set to 0.0, no data is used. Default is ``1.0``.
:raises ValueError: If `data_size` is not in the range [0.0, 1.0].
:raises ValueError: If `data_size` is not a float.
"""
@@ -147,3 +174,5 @@ def __init__(self, load=True, data_size=1.0):
self.conditions["data"] = Condition(
input=input_data[:n_data], target=output_data[:n_data]
)
+ self.conditions["data"].problem = self
+ self.conditions["data"].name = "data"
diff --git a/pina/problem/zoo/poisson_2d_square.py b/pina/_src/problem/zoo/poisson_problem.py
similarity index 55%
rename from pina/problem/zoo/poisson_2d_square.py
rename to pina/_src/problem/zoo/poisson_problem.py
index 5de38b301..34d86c6fb 100644
--- a/pina/problem/zoo/poisson_2d_square.py
+++ b/pina/_src/problem/zoo/poisson_problem.py
@@ -1,10 +1,12 @@
"""Formulation of the Poisson problem in a square domain."""
import torch
-from ...equation import FixedValue, Poisson
-from ...problem import SpatialProblem
-from ...domain import CartesianDomain
-from ... import Condition
+
+from pina._src.equation.zoo.fixed_value import FixedValue
+from pina._src.domain.cartesian_domain import CartesianDomain
+from pina._src.problem.spatial_problem import SpatialProblem
+from pina._src.condition.condition import Condition
+from pina._src.equation.zoo.poisson_equation import PoissonEquation
def forcing_term(input_):
@@ -24,8 +26,40 @@ def forcing_term(input_):
class Poisson2DSquareProblem(SpatialProblem):
r"""
- Implementation of the 2-dimensional Poisson problem in the square domain
- :math:`[0, 1] \times [0, 1]`.
+ Implementation of the two-dimensional Poisson problem on the square domain
+ :math:`\Omega = [0, 1] \times [0, 1]`.
+
+ The problem is governed by the Poisson equation
+
+ .. math::
+
+ \Delta u = f(x, y),
+
+ where :math:`u = u(x, y)` is the solution field and :math:`f(x, y)` is the
+ forcing term.
+
+ Homogeneous Dirichlet boundary conditions are imposed on the boundary of the
+ domain:
+
+ .. math::
+
+ u(x, y) = 0, \qquad (x, y) \in \partial \Omega.
+
+ The forcing term is given by
+
+ .. math::
+
+ f(x, y)
+ =
+ 2\pi^2 \sin(\pi x)\sin(\pi y).
+
+ The analytical solution is given by
+
+ .. math::
+
+ u(x, y)
+ =
+ -\sin(\pi x)\sin(\pi y).
:Example:
@@ -42,7 +76,9 @@ class Poisson2DSquareProblem(SpatialProblem):
conditions = {
"boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
- "D": Condition(domain="D", equation=Poisson(forcing_term=forcing_term)),
+ "D": Condition(
+ domain="D", equation=PoissonEquation(forcing_term=forcing_term)
+ ),
}
def solution(self, pts):
diff --git a/pina/problem/zoo/supervised_problem.py b/pina/_src/problem/zoo/supervised_problem.py
similarity index 77%
rename from pina/problem/zoo/supervised_problem.py
rename to pina/_src/problem/zoo/supervised_problem.py
index 61a49c0cb..fea7f80a3 100644
--- a/pina/problem/zoo/supervised_problem.py
+++ b/pina/_src/problem/zoo/supervised_problem.py
@@ -1,15 +1,14 @@
"""Formulation of a Supervised Problem in PINA."""
-from ..abstract_problem import AbstractProblem
-from ... import Condition
+from pina._src.problem.base_problem import BaseProblem
+from pina._src.condition.condition import Condition
-class SupervisedProblem(AbstractProblem):
+class SupervisedProblem(BaseProblem):
"""
Definition of a supervised-learning problem.
- This class provides a simple way to define a supervised problem
- using a single condition of type
+ This class provides a simple way to define a supervised problem using the
:class:`~pina.condition.input_target_condition.InputTargetCondition`.
:Example:
@@ -20,6 +19,9 @@ class SupervisedProblem(AbstractProblem):
>>> problem = SupervisedProblem(input_data, output_data)
"""
+ # TODO: This is necessary to override the abstract properties of
+ # BaseProblem, but it is not an ideal solution. We should consider
+ # a different desgin to manage input and output variables.
conditions = {}
output_variables = None
input_variables = None
@@ -36,10 +38,10 @@ def __init__(
:type output_: torch.Tensor | LabelTensor | Graph | Data
:param list[str] input_variables: List of names of the input variables.
If None, the input variables are inferred from `input_`.
- Default is None.
+ Default is ``None``.
:param list[str] output_variables: List of names of the output
variables. If None, the output variables are inferred from
- `output_`. Default is None.
+ `output_`. Default is ``None``.
"""
# Set input and output variables
self.input_variables = input_variables
diff --git a/pina/_src/solver/autoregressive_ensemble_solver.py b/pina/_src/solver/autoregressive_ensemble_solver.py
new file mode 100644
index 000000000..27e00947a
--- /dev/null
+++ b/pina/_src/solver/autoregressive_ensemble_solver.py
@@ -0,0 +1,117 @@
+"""Module for the autoregressive ensemble solver class."""
+
+from pina._src.solver.mixin.autoregressive_mixin import AutoregressiveMixin
+from pina._src.condition.time_series_condition import TimeSeriesCondition
+from pina._src.solver.ensemble_solver import EnsembleSolver
+
+
+class AutoregressiveEnsembleSolver(AutoregressiveMixin, EnsembleSolver):
+ r"""
+ Ensemble-model solver for autoregressive learning problems.
+
+ This solver learns the time evolution of dynamical systems using an
+ ensemble of models. It is intended for problems defined by time-series data
+ and accepts only
+ :class:`~pina._src.condition.time_series_condition.TimeSeriesCondition`.
+
+ Given a sequence of states :math:`\{\mathbf{u}_t\}_{t=0}^{T}`, the solver
+ trains an ensemble of models :math:`\{\mathcal{M}_j\}_{j=1}^{M}` to predict
+ the next state from the current one. The prediction of each model is
+
+ .. math::
+
+ \hat{\mathbf{u}}_{t+1}^{(j)} = \mathcal{M}_j(\mathbf{u}_t),
+ \qquad j = 1, \ldots, M.
+
+ The autoregressive training objective minimizes the discrepancy between
+ the predicted states :math:`\hat{\mathbf{u}}_{t+1}^{(j)}` and the target
+ states :math:`\mathbf{u}_{t+1}` over the sequence and across the ensemble:
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{M} \sum_{j=1}^{M}
+ \frac{1}{T} \sum_{t=0}^{T-1} \mathcal{L}
+ \left( \mathbf{u}_{t+1} - \hat{\mathbf{u}}_{t+1}^{(j)} \right),
+
+ where :math:`\mathcal{L}` is the selected loss function, typically the
+ mean squared error.
+
+ The solver supports adaptive weighting of autoregressive steps through the
+ ``eps`` parameter. During training, each autoregressive step can contribute
+ differently to the total loss depending on its accumulated difficulty. Steps
+ with larger running losses are assigned larger weights, so that the solver
+ focuses more on parts of the rollout where prediction errors tend to
+ accumulate. The parameter ``eps`` controls the strength of this effect:
+ ``eps = 0`` disables adaptive weighting, while larger values increase the
+ influence of high-loss steps on the final training objective.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (TimeSeriesCondition,)
+
+ def __init__(
+ self,
+ problem,
+ models,
+ optimizers=None,
+ schedulers=None,
+ weighting=None,
+ loss=None,
+ use_lt=False,
+ eps=0.0,
+ reset_weights_at_epoch_start=True,
+ kwargs=None,
+ ):
+ """
+ Initialization of the :class:`AutoregressiveEnsembleSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param models: The model or list of models used by the solver.
+ :type models: torch.nn.Module | list[torch.nn.Module]
+ :param optimizers: The optimizer or list of optimizers used by the
+ solver. If ``None``, the ``torch.optim.Adam`` optimizer with a
+ learning rate of ``0.001`` is used for each model.
+ Default is ``None``.
+ :type optimizers: TorchOptimizer | list[TorchOptimizer]
+ :param schedulers: The scheduler or list of schedulers used by the
+ solver. If ``None``, the ``torch.optim.lr_scheduler.ConstantLR``
+ scheduler with a factor of ``1.0`` is used for each model.
+ Default is ``None``.
+ :type schedulers: TorchScheduler | list[TorchScheduler]
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+ Default is ``False``.
+ :param eps: The hyperparameter controlling the influence of the
+ cumulative loss on the adaptive weights. Higher values of eps will
+ lead to more aggressive weighting of steps with higher cumulative
+ loss. Default is ``0.0``.
+ :type eps: float | int
+ :param bool reset_weights_at_epoch_start: Whether to reset the running
+ average and step count at the start of each epoch. If ``True``, the
+ adaptive weights will be recalibrated at the beginning of each epoch
+ based on the new training dynamics. Default is ``True``.
+ :param dict kwargs: Additional keyword arguments for preprocessing and
+ postprocessing steps.
+ """
+ # Initialize the parent class
+ EnsembleSolver.__init__(
+ self,
+ problem=problem,
+ models=models,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ weighting=weighting,
+ loss=loss,
+ use_lt=use_lt,
+ )
+
+ # Initialize the autoregressive components
+ self._init_autoregressive_components(
+ eps=eps,
+ reset_weights_at_epoch_start=reset_weights_at_epoch_start,
+ kwargs=kwargs,
+ )
diff --git a/pina/_src/solver/autoregressive_single_model_solver.py b/pina/_src/solver/autoregressive_single_model_solver.py
new file mode 100644
index 000000000..8d2f0b9ca
--- /dev/null
+++ b/pina/_src/solver/autoregressive_single_model_solver.py
@@ -0,0 +1,111 @@
+"""Module for the autoregressive single model solver class."""
+
+from pina._src.solver.mixin.autoregressive_mixin import AutoregressiveMixin
+from pina._src.condition.time_series_condition import TimeSeriesCondition
+from pina._src.solver.single_model_solver import SingleModelSolver
+
+
+class AutoregressiveSingleModelSolver(AutoregressiveMixin, SingleModelSolver):
+ r"""
+ Single-model solver for autoregressive learning problems.
+
+ This solver learns the time evolution of dynamical systems using a single
+ model. It is intended for problems defined by time-series data and accepts
+ only
+ :class:`~pina._src.condition.time_series_condition.TimeSeriesCondition`.
+
+ Given a sequence of states :math:`\{\mathbf{u}_t\}_{t=0}^{T}`, the solver
+ trains a model :math:`\mathcal{M}` to predict the next state from the
+ current one:
+
+ .. math::
+
+ \hat{\mathbf{u}}_{t+1} = \mathcal{M}(\mathbf{u}_t).
+
+ The autoregressive training objective minimizes the discrepancy between
+ the predicted states :math:`\hat{\mathbf{u}}_{t+1}` and the target states
+ :math:`\mathbf{u}_{t+1}` over the sequence:
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{T} \sum_{t=0}^{T-1}
+ \mathcal{L} \left( \mathbf{u}_{t+1} - \hat{\mathbf{u}}_{t+1} \right),
+
+ where :math:`\mathcal{L}` is the selected loss function, typically the mean
+ squared error.
+
+ The solver supports adaptive weighting of autoregressive steps through the
+ ``eps`` parameter. During training, each autoregressive step can contribute
+ differently to the total loss depending on its accumulated difficulty. Steps
+ with larger running losses are assigned larger weights, so that the solver
+ focuses more on parts of the rollout where prediction errors tend to
+ accumulate. The parameter ``eps`` controls the strength of this effect:
+ ``eps = 0`` disables adaptive weighting, while larger values increase the
+ influence of high-loss steps on the final training objective.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (TimeSeriesCondition,)
+
+ def __init__(
+ self,
+ problem,
+ model,
+ optimizer=None,
+ scheduler=None,
+ weighting=None,
+ loss=None,
+ use_lt=False,
+ eps=0.0,
+ reset_weights_at_epoch_start=True,
+ kwargs=None,
+ ):
+ """
+ Initialization of the :class:`AutoregressiveSingleModelSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param torch.nn.Module model: The model used by the solver.
+ :param TorchOptimizer optimizer: The optimizer used by the solver.
+ If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
+ of ``0.001`` is used. Default is ``None``.
+ :param TorchScheduler scheduler: The scheduler used by the solver.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+ Default is ``False``.
+ :param eps: The hyperparameter controlling the influence of the
+ cumulative loss on the adaptive weights. Higher values of eps will
+ lead to more aggressive weighting of steps with higher cumulative
+ loss. Default is ``0.0``.
+ :type eps: float | int
+ :param bool reset_weights_at_epoch_start: Whether to reset the running
+ average and step count at the start of each epoch. If ``True``, the
+ adaptive weights will be recalibrated at the beginning of each epoch
+ based on the new training dynamics. Default is ``True``.
+ :param dict kwargs: Additional keyword arguments for preprocessing and
+ postprocessing steps.
+ """
+
+ # Initialize the parent class
+ SingleModelSolver.__init__(
+ self,
+ problem=problem,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ weighting=weighting,
+ loss=loss,
+ use_lt=use_lt,
+ )
+
+ # Initialize the autoregressive components
+ self._init_autoregressive_components(
+ eps=eps,
+ reset_weights_at_epoch_start=reset_weights_at_epoch_start,
+ kwargs=kwargs,
+ )
diff --git a/pina/_src/solver/base_solver.py b/pina/_src/solver/base_solver.py
new file mode 100644
index 000000000..da6f5a60a
--- /dev/null
+++ b/pina/_src/solver/base_solver.py
@@ -0,0 +1,435 @@
+"""Module for the base solver class."""
+
+from abc import ABCMeta
+import lightning
+import torch
+from pina._src.core.utils import labelize_forward, check_consistency
+from pina._src.solver.solver_interface import SolverInterface
+from pina._src.weighting.base_weighting import BaseWeighting
+from pina._src.problem.inverse_problem import InverseProblem
+from pina._src.optim.torch_optimizer import TorchOptimizer
+from pina._src.optim.torch_scheduler import TorchScheduler
+from pina._src.weighting.no_weighting import _NoWeighting
+from pina._src.problem.base_problem import BaseProblem
+from pina._src.loss.base_dual_loss import BaseDualLoss
+
+
+class BaseSolver(SolverInterface, metaclass=ABCMeta):
+ """
+ Base class for all solvers, implementing common functionality.
+
+ All solvers must inherit from this class and implement abstract methods
+ defined in :class:`~pina.solver.solver_interface.SolverInterface`.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ # Define the available reductions for loss computation
+ _AVAILABLE_REDUCTIONS = {
+ "none": lambda x: x,
+ "mean": lambda x: x.mean(),
+ "sum": lambda x: x.sum(),
+ }
+
+ def __init__(self, problem, use_lt=True):
+ """
+ Initialization of the :class:`BaseSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+ Default is ``True``.
+ :raises ValueError: If ``use_lt`` is not a boolean.
+ :raises ValueError: If ``problem`` is not an instance of
+ :class:`~pina.problem.base_problem.BaseProblem`.
+ :raises ValueError: If one or more problem conditions are not supported
+ by the solver.
+ """
+ # Reset the solver state
+ self.reset()
+
+ # Call the parent class initializer
+ lightning.pytorch.LightningModule.__init__(self)
+
+ # Check consistency
+ check_consistency(use_lt, bool)
+ check_consistency(problem, BaseProblem)
+ for condition in problem.conditions.values():
+ check_consistency(condition, self.accepted_conditions_types)
+
+ # Initialize the solver components
+ self._pina_problem = problem
+ self._use_lt = use_lt
+
+ # Manage InverseProblem parameters if needed
+ if isinstance(self.problem, InverseProblem):
+ self._params = self.problem.unknown_parameters
+ self._clamp_params = self._clamp_inverse_problem_params
+ else:
+ self._params = None
+ self._clamp_params = lambda: None
+
+ # Labelize the forward method if using LabelTensors
+ if self.use_lt:
+ self.forward = labelize_forward(
+ forward=self.forward,
+ input_variables=problem.input_variables,
+ output_variables=problem.output_variables,
+ )
+
+ def reset(self):
+ """
+ Reset the internal solver state, clearing the stored problem, models,
+ optimizers and schedulers.
+ """
+ self._pina_problem = None
+ self._pina_models = None
+ self._pina_optimizers = None
+ self._pina_schedulers = None
+
+ def _clamp_inverse_problem_params(self):
+ """
+ Clamp the unknown parameters of an inverse problem. Each unknown
+ parameter is constrained to lie within the corresponding bounds defined
+ by the inverse problem parameter domain.
+ """
+ for v in self._params:
+ self._params[v].data.clamp_(
+ self.problem.unknown_parameter_domain.range[v][0],
+ self.problem.unknown_parameter_domain.range[v][1],
+ )
+
+ def _init_weighting_and_loss(self, weighting=None, loss=None):
+ """
+ Initialize the weighting strategy and loss function.
+
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :type loss: torch.nn.Module | BaseDualLoss
+ :raises ValueError: If ``weighting`` is not an instance of
+ :class:`~pina.weighting.base_weighting.BaseWeighting`.
+ :raises ValueError: If ``loss`` is not a valid PyTorch loss or
+ :class:`~pina.loss.base_dual_loss.BaseDualLoss`.
+ """
+ # If no weighting schema is provided, use a default no-weighting schema
+ if weighting is None:
+ weighting = _NoWeighting()
+
+ # Set default loss function to MSE if not provided
+ if loss is None:
+ loss = torch.nn.MSELoss()
+
+ # Check consistency
+ check_consistency(weighting, BaseWeighting)
+ check_consistency(loss, (BaseDualLoss, torch.nn.modules.loss._Loss))
+
+ # Store the weighting and loss function for use in the solver
+ self._pina_weighting = weighting
+ weighting._solver = self
+ self._loss_fn = loss
+ self._reduction = getattr(loss, "reduction", "mean")
+ if hasattr(self._loss_fn, "reduction"):
+ self._loss_fn.reduction = "none"
+
+ def _init_solver_components(
+ self,
+ models,
+ optimizers=None,
+ schedulers=None,
+ ):
+ """
+ Initialize the solver models, optimizers and schedulers.
+
+ :param models: The model or list of models used by the solver.
+ :type models: torch.nn.Module | list[torch.nn.Module]
+ :param optimizers: The optimizer or list of optimizers used by the
+ solver. If ``None``, the ``torch.optim.Adam`` optimizer with a
+ learning rate of ``0.001`` is used for each model.
+ Default is ``None``.
+ :type optimizers: TorchOptimizer | list[TorchOptimizer]
+ :param schedulers: The scheduler or list of schedulers used by the
+ solver. If ``None``, the ``torch.optim.lr_scheduler.ConstantLR``
+ scheduler with a factor of ``1.0`` is used for each model.
+ Default is ``None``.
+ :type schedulers: TorchScheduler | list[TorchScheduler]
+ :raises ValueError: If ``models`` are not instances of
+ :class:`torch.nn.Module`.
+ :raises ValueError: If ``optimizers`` are not instances of
+ :class:`~pina.optim.torch_optimizer.TorchOptimizer`.
+ :raises ValueError: If ``schedulers`` are not instances of
+ :class:`~pina.optim.torch_scheduler.TorchScheduler`.
+ :raises ValueError: If the number of optimizers does not match that of
+ models.
+ :raises ValueError: If the number of schedulers does not match that of
+ models.
+ """
+
+ # Helper function to map single items to lists if needed
+ _to_list = lambda x: [x] if not isinstance(x, (list, tuple)) else x
+
+ # Map models to list if a single model is provided
+ models = _to_list(models)
+
+ # Set default optimizers to Adam if not provided
+ if optimizers is None:
+ optimizers = [
+ TorchOptimizer(torch.optim.Adam, lr=0.001)
+ for _ in range(len(models))
+ ]
+
+ # Set default schedulers to ConstantLR if not provided
+ if schedulers is None:
+ schedulers = [
+ TorchScheduler(torch.optim.lr_scheduler.ConstantLR, factor=1.0)
+ for _ in range(len(models))
+ ]
+
+ # Map optimizers and schedulers to lists if single items are provided
+ optimizers = _to_list(optimizers)
+ schedulers = _to_list(schedulers)
+
+ # Check consistency
+ check_consistency(optimizers, TorchOptimizer)
+ check_consistency(schedulers, TorchScheduler)
+ check_consistency(models, torch.nn.Module)
+
+ # Check that the number of optimizers matches the number of models
+ if len(optimizers) != len(models):
+ raise ValueError(
+ "You must define one optimizer for each model."
+ f"Got {len(models)} models, and {len(optimizers)} optimizers."
+ )
+
+ # Check that the number of schedulers matches the number of models
+ if len(schedulers) != len(models):
+ raise ValueError(
+ "You must define one scheduler for each model."
+ f"Got {len(models)} models, and {len(schedulers)} schedulers."
+ )
+
+ # Initialize the solver components
+ self._pina_models = torch.nn.ModuleList(models)
+ self._pina_optimizers = optimizers
+ self._pina_schedulers = schedulers
+
+ def training_step(self, batch, batch_idx):
+ """
+ Solver training step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+ loss = self.batch_evaluation_step(batch=batch, batch_idx=batch_idx)
+ self.log(
+ name="train_loss",
+ value=loss.item(),
+ batch_size=self.get_batch_size(batch),
+ **self.trainer.logging_kwargs,
+ )
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ """
+ Solver validation step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+ loss = self.batch_evaluation_step(batch=batch, batch_idx=batch_idx)
+ self.log(
+ name="val_loss",
+ value=loss.item(),
+ batch_size=self.get_batch_size(batch),
+ **self.trainer.logging_kwargs,
+ )
+ return loss
+
+ def test_step(self, batch, batch_idx):
+ """
+ Solver test step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+ loss = self.batch_evaluation_step(batch=batch, batch_idx=batch_idx)
+ self.log(
+ name="test_loss",
+ value=loss.item(),
+ batch_size=self.get_batch_size(batch),
+ **self.trainer.logging_kwargs,
+ )
+ return loss
+
+ def _compute_condition_loss(self, condition, data, batch_idx):
+ """
+ Compute the scalar loss for a given condition and its data.
+
+ :param BaseCondition condition: The condition for which to compute the
+ loss.
+ :param dict data: The data corresponding to the condition.
+ :param int batch_idx: The index of the current batch.
+ :return: The scalar loss for the condition.
+ :rtype: torch.Tensor
+ """
+ # Clone the input tensor if it exists to avoid in-place modifications
+ if "input" in data and hasattr(data["input"], "clone"):
+ data = dict(data)
+ data["input"] = data["input"].clone()
+
+ # Prepare condition data, e.g. by enabling gradient for regularizations
+ data = self._prepare_condition_data(data=data)
+
+ # Compute and store the residual tensor for the condition
+ self.residual_tensor = condition.evaluate(data, self)
+
+ # Retrieve condition name for more complex weighting schemes
+ condition_name = condition.name if hasattr(condition, "name") else None
+
+ # Compute the tensor loss from the residual tensor
+ condition_tensor_loss = self._loss_from_residual(condition_name)
+
+ # Optional regularization hook, e.g gradient-enhanced or residual-based
+ condition_tensor_loss = self._regularize_condition_loss(
+ condition_tensor_loss=condition_tensor_loss,
+ condition_name=condition_name,
+ data=data,
+ batch_idx=batch_idx,
+ )
+
+ # Compute the scalar loss from the tensor loss and return it
+ condition_scalar_loss = self._apply_reduction(condition_tensor_loss)
+
+ return condition_scalar_loss
+
+ def _prepare_condition_data(self, data):
+ """
+ Prepare the condition data for loss computation. This method can be
+ overridden by mixins to implement specific data preparation steps, such
+ as enabling gradient tracking for inputs in gradient-enhanced solvers.
+
+ :param dict data: The original condition data.
+ :return: The prepared condition data.
+ :rtype: dict
+ """
+ return data
+
+ def _regularize_condition_loss(
+ self,
+ condition_tensor_loss,
+ condition_name,
+ data,
+ batch_idx,
+ ):
+ """
+ Regularize the condition loss if needed. This method can be overridden
+ by mixins to implement specific regularization strategies, such as
+ adding a gradient penalty in gradient-enhanced solvers or applying
+ residual-based attention.
+
+ :param condition_tensor_loss: The original tensor loss for the
+ condition.
+ :type condition_tensor_loss: torch.Tensor | LabelTensor
+ :param str condition_name: The name of the condition.
+ :param dict data: The data corresponding to the condition.
+ :param int batch_idx: The index of the current batch.
+ :return: The regularized tensor loss for the condition.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ return condition_tensor_loss
+
+ def _loss_from_residual(self, condition_name=None):
+ """
+ Compute the tensor loss from the residual tensor.
+
+ :param str condition_name: The name of the condition.
+ :return: The tensor loss computed from the residual tensor.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Compute the loss tensor and appply reduction
+ return self._loss_fn(
+ self.residual_tensor, torch.zeros_like(self.residual_tensor)
+ )
+
+ def _apply_reduction(self, value):
+ """
+ Apply the specified reduction to the loss tensor.
+
+ :param value: The loss tensor to reduce.
+ :type value: torch.Tensor | LabelTensor
+ :return: The reduced loss.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Get the reduction function based on the specified reduction type
+ reduction_fn = self._AVAILABLE_REDUCTIONS.get(self._reduction)
+
+ # If the reduction type is not supported, raise an error
+ if reduction_fn is None:
+ raise ValueError(
+ f"Unsupported reduction '{self._reduction}'. "
+ f"Available options include {self._AVAILABLE_REDUCTIONS.keys()}"
+ )
+
+ return reduction_fn(value)
+
+ @staticmethod
+ def get_batch_size(batch):
+ """
+ Get the batch size.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :return: The size of the batch.
+ :rtype: int
+ """
+ return sum(len(data[1]["input"]) for data in batch)
+
+ @property
+ def problem(self):
+ """
+ The problem instance.
+
+ :return: The problem instance.
+ :rtype: :class:`~pina.problem.base_problem.BaseProblem`
+ """
+ return self._pina_problem
+
+ @property
+ def use_lt(self):
+ """
+ Using LabelTensors as input during training.
+
+ :return: The use_lt attribute.
+ :rtype: bool
+ """
+ return self._use_lt
+
+ @property
+ def weighting(self):
+ """
+ The weighting schema used by the solver.
+
+ :return: The weighting schema used by the solver.
+ :rtype: :class:`~pina.weighting.base_weighting.BaseWeighting`
+ """
+ return self._pina_weighting
+
+ @property
+ def loss(self):
+ """
+ The element-wise loss module used by the solver.
+
+ :return: The element-wise loss module used by the solver.
+ :rtype: torch.nn.Module
+ """
+ return self._loss_fn
diff --git a/pina/_src/solver/causal_physics_informed_single_model_solver.py b/pina/_src/solver/causal_physics_informed_single_model_solver.py
new file mode 100644
index 000000000..db243e020
--- /dev/null
+++ b/pina/_src/solver/causal_physics_informed_single_model_solver.py
@@ -0,0 +1,336 @@
+"""Module for the causal physics-informed single-model solver class."""
+
+import torch
+from pina._src.condition.input_equation_condition import InputEquationCondition
+from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.core.utils import check_consistency, check_positive_integer
+from pina._src.problem.time_dependent_problem import TimeDependentProblem
+from pina._src.solver.single_model_solver import SingleModelSolver
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+
+
+class CausalPhysicsInformedSingleModelSolver(
+ PhysicsInformedMixin, SingleModelSolver
+):
+ r"""
+ Single-model solver for causal physics-informed learning problems.
+
+ This solver approximates the solution of a time-dependent differential
+ problem using a single model and a causality-aware training objective. It is
+ intended for problems whose conditions include equation residuals and
+ boundary residuals evaluated across ordered time snapshots. It can be used
+ only for forward problems, due to the causal nature of the training
+ objective.
+
+ Given a model :math:`\mathcal{M}`, the predicted solution is
+
+ .. math::
+
+ \hat{\mathbf{u}}(\mathbf{x}, t) = \mathcal{M}(\mathbf{x}, t).
+
+ The solver minimizes a causal residual loss that weights each time snapshot
+ according to the residuals accumulated at previous times. For a time
+ dependent problem with governing equation operator :math:`\mathcal{A}` in
+ the domain :math:`\Omega` and boundary operator :math:`\mathcal{B}` on the
+ boundary :math:`\partial\Omega`, the objective can be written as
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{N_t} \sum_{i=1}^{N_t}
+ \omega_i \mathcal{L}_r(t_i),
+
+ where the residual loss at time :math:`t` is
+
+ .. math::
+
+ \mathcal{L}_r(t) = \frac{1}{N_{\Omega}} \sum_{j=1}^{N_{\Omega}}
+ \mathcal{L}\left( \mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_j, t) \right)
+ + \frac{1}{N_{\partial\Omega}} \sum_{j=1}^{N_{\partial\Omega}}
+ \mathcal{L} \left( \mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_j, t)
+ \right).
+
+ The causal weights are defined as
+
+ .. math::
+
+ \omega_i = \exp \left( -\epsilon \sum_{k=1}^{i-1} \mathcal{L}_r(t_k)
+ \right),
+
+ where :math:`\epsilon` is a hyperparameter controlling the strength of the
+ causal weighting, and :math:`\mathcal{L}` is the selected loss function,
+ typically the mean squared error.
+
+ .. seealso::
+
+ **Original reference**: Wang, S., Sankaran, S., & Perdikaris, P. (2024).
+ *Respecting causality for training physics-informed neural networks.*
+ Computer Methods in Applied Mechanics and Engineering, 421, 116813.
+ DOI: `10.1016/j.cma.2024.116813
+ `_.
+
+ .. note::
+
+ This solver is compatible only with problems inheriting from
+ :class:`~pina.problem.time_dependent_problem.TimeDependentProblem`.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (
+ InputTargetCondition,
+ InputEquationCondition,
+ DomainEquationCondition,
+ )
+
+ def __init__(
+ self,
+ problem,
+ model,
+ optimizer=None,
+ scheduler=None,
+ weighting=None,
+ loss=None,
+ eps=100,
+ n_steps=10,
+ regularized_conditions=None,
+ ):
+ """
+ Initialization of the :class:`CausalPhysicsInformedSingleModelSolver`
+ class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param torch.nn.Module model: The model used by the solver.
+ :param TorchOptimizer optimizer: The optimizer used by the solver.
+ If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
+ of ``0.001`` is used. Default is ``None``.
+ :param TorchScheduler scheduler: The scheduler used by the solver.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param eps: The exponential decay parameter. Default is ``100``.
+ :type eps: float | int
+ :param int n_steps: The number of equispaced temporal steps used to
+ compute the causal loss. Default is ``10``.
+ :param regularized_conditions: The names of the conditions that should
+ receive causal regularization. Default is ``None``.
+ :raises ValueError: If the problem is not time-dependent.
+ :raises ValueError: If the user does not specify any regularized
+ conditions.
+ :raises ValueError: If any of the specified ``regularized_conditions``
+ are not present in the ``problem``'s conditions.
+ :raises ValueError: If ``eps`` is not a float or int.
+ :raises ValueError: If ``n_steps`` is not a positive integer.
+ """
+ # Ensure the problem is time-dependent
+ if not isinstance(problem, TimeDependentProblem):
+ raise ValueError(
+ "Causal physics-informed solvers require the problem to be an "
+ f"instance of TimeDependentProblem. Got {type(problem)}."
+ )
+
+ # Ensure the user specified valid regularized conditions
+ if regularized_conditions is None:
+ raise ValueError(
+ "Causal physics-informed solvers require the user to specify "
+ "the conditions that should receive causal regularization. "
+ "Apply causal regularization only to time-dependent conditions."
+ )
+
+ # Check consistency
+ check_consistency(eps, (int, float))
+ check_consistency(regularized_conditions, str)
+ check_positive_integer(n_steps, strict=True)
+
+ # Map conditions to list if a single condition is provided
+ if not isinstance(regularized_conditions, (list, tuple)):
+ regularized_conditions = [regularized_conditions]
+
+ # Ensure that all regularized conditions are present in the problem
+ problem_conditions = set(problem.conditions.keys())
+ for condition in regularized_conditions:
+ if condition not in problem_conditions:
+ raise ValueError(
+ f"Condition '{condition}' is not present in the problem."
+ )
+
+ # Initialize the parent class
+ SingleModelSolver.__init__(
+ self,
+ problem=problem,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ weighting=weighting,
+ loss=loss,
+ use_lt=True,
+ )
+
+ # Initialize the causal regularization parameters
+ self.eps = eps
+ self.n_steps = n_steps
+ self.regularized_conditions = regularized_conditions
+
+ def _compute_condition_loss(self, condition, data, batch_idx):
+ """
+ Compute the scalar loss for a given condition and its data.
+
+ :param BaseCondition condition: The condition for which to compute the
+ loss.
+ :param dict data: The data corresponding to the condition.
+ :param int batch_idx: The index of the current batch.
+ :return: The scalar loss for the condition.
+ :rtype: torch.Tensor
+ """
+ # If the condition is not regularized, or is a supervised (target)
+ # condition, use the standard loss computation
+ if condition.name not in self.regularized_conditions or isinstance(
+ condition, InputTargetCondition
+ ):
+ return super()._compute_condition_loss(condition, data, batch_idx)
+
+ # Clone the input tensor if it exists to avoid in-place modifications
+ if "input" in data and hasattr(data["input"], "clone"):
+ data = dict(data)
+ data["input"] = data["input"].clone()
+
+ # Prepare condition data, e.g. by enabling gradient for regularizations
+ data = self._prepare_condition_data(data=data)
+
+ # Extract the temporal domain
+ time_domain = self.problem.temporal_domain
+
+ # Define the time steps for the causal loss computation
+ if time_domain.range:
+ time_steps = torch.linspace(
+ time_domain.range[self.temporal_variable][0],
+ time_domain.range[self.temporal_variable][1],
+ self.n_steps,
+ device=data["input"].device,
+ dtype=data["input"].dtype,
+ )
+
+ # If no range is defined, use the unique temporal value
+ else:
+ time_steps = torch.tensor(
+ [time_domain.fixed[self.temporal_variable]],
+ device=data["input"].device,
+ dtype=data["input"].dtype,
+ )
+
+ # Initialize the list to store the loss for each time step
+ time_loss = []
+
+ # Iterate over the time steps
+ for step in time_steps:
+
+ # Append the temporal variable to the spatial input tensor
+ spatial_pts = data["input"].extract(self.spatial_variables)
+ time_pts = LabelTensor(
+ torch.ones(spatial_pts.shape[0], 1, device=spatial_pts.device)
+ * step,
+ labels=[self.temporal_variable],
+ )
+ pts = {
+ "input": LabelTensor.cat(
+ [spatial_pts, time_pts], dim=1
+ ).requires_grad_(True)
+ }
+
+ # Compute and store the residual tensor for the condition
+ self.residual_tensor = condition.evaluate(pts, self)
+
+ # Retrieve condition name for more complex weighting schemes
+ condition_name = (
+ condition.name if hasattr(condition, "name") else None
+ )
+
+ # Compute the tensor loss from the residual tensor
+ condition_tensor_loss = self._loss_from_residual(condition_name)
+
+ # Optional regularization hook
+ condition_tensor_loss = self._regularize_condition_loss(
+ condition_tensor_loss=condition_tensor_loss,
+ condition_name=condition_name,
+ data=data,
+ batch_idx=batch_idx,
+ )
+
+ # Append the loss for the current time step to the list
+ time_loss.append(condition_tensor_loss)
+
+ # Compute the time-adaptive weights for the causal loss
+ time_loss = torch.stack(time_loss)
+ with torch.no_grad():
+ weights = self._compute_weights(time_loss)
+
+ # Compute the scalar loss from the tensor loss and return it
+ condition_scalar_loss = self._apply_reduction(weights * time_loss)
+
+ return condition_scalar_loss
+
+ def _compute_weights(self, loss):
+ """
+ Compute the temporal adaptive weights for the causal loss based on the
+ cumulative loss.
+
+ :param LabelTensor loss: The physics loss values.
+ :return: The computed weights for the physics loss.
+ :rtype: LabelTensor
+ """
+ # Compute the cumulative loss and apply exponential decay
+ cumulative_loss = torch.cumsum(loss, dim=0)
+ return torch.exp(-self.eps * cumulative_loss)
+
+ @property
+ def temporal_variable(self):
+ """
+ The temporal variable of the problem.
+
+ :return: The temporal variable name.
+ :rtype: str
+ :raises ValueError: If the problem does not have exactly one temporal
+ variable.
+ """
+ # Extract the temporal variable from the problem
+ temporal_variables = self.problem.temporal_variables
+
+ # Raise error if there is not exactly one temporal variable
+ if len(temporal_variables) != 1:
+ raise ValueError(
+ "Causal physics-informed solvers require exactly one temporal "
+ f"variable. Got {temporal_variables}."
+ )
+
+ return temporal_variables[0]
+
+ @property
+ def spatial_variables(self):
+ """
+ The spatial variables of the problem.
+
+ :return: The spatial variable names.
+ :rtype: list[str]
+ :raises ValueError: If the problem does not have any spatial variables.
+ """
+ # Determine the spatial variables by excluding the temporal variable
+ spatial_variables = [
+ v
+ for v in self.problem.input_variables
+ if v != self.temporal_variable
+ ]
+
+ # Raise error if there are no spatial variables left
+ if not spatial_variables:
+ raise ValueError(
+ "Causal physics-informed solvers require at least one spatial "
+ "variable in addition to time."
+ )
+
+ return spatial_variables
diff --git a/pina/_src/solver/competitive_physics_informed_solver.py b/pina/_src/solver/competitive_physics_informed_solver.py
new file mode 100644
index 000000000..70ed77030
--- /dev/null
+++ b/pina/_src/solver/competitive_physics_informed_solver.py
@@ -0,0 +1,306 @@
+"""Module for the competitive physics-informed multi-model solver."""
+
+import copy
+from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
+from pina._src.condition.input_equation_condition import InputEquationCondition
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.solver.multi_model_solver import MultiModelSolver
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+
+
+class CompetitivePhysicsInformedSolver(PhysicsInformedMixin, MultiModelSolver):
+ r"""
+ Multi-model solver for competitive physics-informed learning problems.
+
+ This solver approximates the solution of a differential problem using a
+ trainable model together with a discriminator network. It is intended for
+ problems whose conditions may include supervised data, equation residuals
+ evaluated on input points, and equation residuals sampled from domains.
+
+ Given a model :math:`\mathcal{M}`, the predicted solution is
+
+ .. math::
+
+ \hat{\mathbf{u}}(\mathbf{x}) = \mathcal{M}(\mathbf{x}).
+
+ The discriminator :math:`D` assigns pointwise weights to the residuals,
+ encouraging the model to focus on regions where the approximation performs
+ poorly. The model parameters are optimized by minimizing the loss, while the
+ discriminator parameters are optimized by maximizing it.
+
+ For a problem with governing equation operator :math:`\mathcal{A}` in the
+ domain :math:`\Omega` and boundary operator :math:`\mathcal{B}` on the
+ boundary :math:`\partial\Omega`, the competitive objective can be written as
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{N_{\Omega}}
+ \sum_{i=1}^{N_{\Omega}} \mathcal{L}
+ \left(D(\mathbf{x}_i)\mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i)\right)
+ +\frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
+ \mathcal{L}
+ \left(D(\mathbf{x}_i)\mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_i)\right),
+
+ where :math:`D` is the discriminator network and :math:`\mathcal{L}` is the
+ selected loss function, typically the mean squared error.
+
+ The model and discriminator are trained through a min-max problem:
+
+ .. math::
+
+ \min_{\theta} \max_{\phi} \mathcal{L}_{\mathrm{problem}},
+
+ where :math:`\theta` denotes the model parameters and :math:`\phi` denotes
+ the discriminator parameters.
+
+ .. seealso::
+
+ **Original reference**: Zeng, Q., Kothari, P., Chou, E., & Masi, G.
+ (2022).
+ *Competitive physics informed networks.*
+ International Conference on Learning Representations, ICLR 2022.
+ `OpenReview Preprint `_.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (
+ InputTargetCondition,
+ InputEquationCondition,
+ DomainEquationCondition,
+ )
+
+ def __init__(
+ self,
+ problem,
+ model,
+ discriminator=None,
+ optimizer_model=None,
+ optimizer_discriminator=None,
+ scheduler_model=None,
+ scheduler_discriminator=None,
+ weighting=None,
+ loss=None,
+ ):
+ """
+ Initialization of the :class:`CompetitivePhysicsInformedSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param torch.nn.Module model: The model used by the solver.
+ :param torch.nn.Module discriminator: The discriminator used by the
+ solver. If ``None``, a deep copy of the model is used as
+ discriminator. Default is ``None``.
+ :param TorchOptimizer optimizer_model: The optimizer of the main model.
+ If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
+ of ``0.001`` is used. Default is ``None``.
+ :param TorchOptimizer optimizer_discriminator: The optimizer of the
+ discriminator. If ``None``, the ``torch.optim.Adam`` optimizer with
+ a learning rate of ``0.001`` is used. Default is ``None``.
+ :param TorchScheduler scheduler_model: The scheduler of the main model.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param TorchScheduler scheduler_discriminator: The scheduler of the
+ discriminator.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :raises ValueError: If ``weight_function`` is not a ``torch.nn.Module``.
+ :raises ValueError: If not all domains have been discretised.
+ """
+ # Initialize the discriminator if not provided
+ if discriminator is None:
+ discriminator = copy.deepcopy(model)
+
+ # Prepare optimizers
+ optimizers = (
+ [optimizer_model, optimizer_discriminator]
+ if any(
+ o is not None
+ for o in (optimizer_model, optimizer_discriminator)
+ )
+ else None
+ )
+
+ # Prepare schedulers
+ schedulers = (
+ [scheduler_model, scheduler_discriminator]
+ if any(
+ s is not None
+ for s in (scheduler_model, scheduler_discriminator)
+ )
+ else None
+ )
+
+ # Initialize the base solver
+ MultiModelSolver.__init__(
+ self,
+ problem=problem,
+ models=[model, discriminator],
+ optimizers=optimizers,
+ schedulers=schedulers,
+ weighting=weighting,
+ loss=loss,
+ use_lt=True,
+ )
+
+ def training_step(self, batch, batch_idx):
+ """
+ Solver training step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+ # Zero the gradients of the model optimizer and compute the loss
+ self.optimizer_model.instance.zero_grad()
+ loss = self.batch_evaluation_step(batch, batch_idx)
+
+ # Perform the backward pass and complete a step for the model
+ self.manual_backward(loss)
+ self.optimizer_model.instance.step()
+ self.scheduler_model.instance.step()
+
+ # Zero the gradients of the discriminator optimizer and compute the loss
+ self.optimizer_discriminator.instance.zero_grad()
+ loss = self.batch_evaluation_step(batch, batch_idx)
+
+ # Perform the backward pass and complete a step for the discriminator
+ self.manual_backward(-loss)
+ self.optimizer_discriminator.instance.step()
+ self.scheduler_discriminator.instance.step()
+
+ # Log the training loss
+ self.log(
+ name="train_loss",
+ value=loss.item(),
+ batch_size=self.get_batch_size(batch),
+ **self.trainer.logging_kwargs,
+ )
+
+ return loss
+
+ def forward(self, x):
+ """
+ Forward pass through the model.
+
+ :param x: The input data.
+ :type x: torch.Tensor | LabelTensor | Data | Graph
+ :return: The output of the model.
+ :rtype: torch.Tensor | LabelTensor | Data | Graph
+ """
+ return self.model(x)
+
+ def _compute_condition_loss(self, condition, data, batch_idx):
+ """
+ Compute the scalar loss for a given condition and its data.
+
+ :param BaseCondition condition: The condition for which to compute the
+ loss.
+ :param dict data: The data corresponding to the condition.
+ :param int batch_idx: The index of the current batch.
+ :return: The scalar loss for the condition.
+ :rtype: torch.Tensor
+ """
+ # Clone the input tensor if it exists to avoid in-place modifications
+ if "input" in data and hasattr(data["input"], "clone"):
+ data = dict(data)
+ data["input"] = data["input"].clone()
+
+ # Prepare condition data, e.g. by enabling gradient for regularizations
+ data = self._prepare_condition_data(data=data)
+
+ # Compute and store the residual tensor for the condition
+ self.residual_tensor = condition.evaluate(data, self)
+
+ # Compute the discriminator bets for the current condition
+ discriminator_input = data["input"][self.problem.input_variables]
+ discriminator_bets = self.discriminator(discriminator_input)
+
+ # Weight the residual tensor using the discriminator bets
+ self.residual_tensor = self.residual_tensor * discriminator_bets
+
+ # Retrieve condition name for more complex weighting schemes
+ condition_name = condition.name if hasattr(condition, "name") else None
+
+ # Compute the tensor loss from the residual tensor
+ condition_tensor_loss = self._loss_from_residual(condition_name)
+
+ # Optional regularization hook, e.g gradient-enhanced or residual-based
+ condition_tensor_loss = self._regularize_condition_loss(
+ condition_tensor_loss=condition_tensor_loss,
+ condition_name=condition_name,
+ data=data,
+ batch_idx=batch_idx,
+ )
+
+ # Compute the scalar loss from the tensor loss and return it
+ condition_scalar_loss = self._apply_reduction(condition_tensor_loss)
+
+ return condition_scalar_loss
+
+ @property
+ def model(self):
+ """
+ The single model used by the solver.
+
+ :return: The single model used by the solver.
+ :rtype: torch.nn.Module
+ """
+ return self._pina_models[0]
+
+ @property
+ def discriminator(self):
+ """
+ The discriminator used by the solver.
+
+ :return: The discriminator used by the solver.
+ :rtype: torch.nn.Module
+ """
+ return self._pina_models[1]
+
+ @property
+ def optimizer_model(self):
+ """
+ The optimizer for the model used by the solver.
+
+ :return: The optimizer for the model used by the solver.
+ :rtype: TorchOptimizer
+ """
+ return self.optimizers[0]
+
+ @property
+ def optimizer_discriminator(self):
+ """
+ The optimizer for the discriminator used by the solver.
+
+ :return: The optimizer for the discriminator used by the solver.
+ :rtype: TorchOptimizer
+ """
+ return self.optimizers[1]
+
+ @property
+ def scheduler_model(self):
+ """
+ The scheduler for the model used by the solver.
+
+ :return: The scheduler for the model used by the solver.
+ :rtype: TorchScheduler
+ """
+ return self.schedulers[0]
+
+ @property
+ def scheduler_discriminator(self):
+ """
+ The scheduler for the discriminator used by the solver.
+
+ :return: The scheduler for the discriminator used by the solver.
+ :rtype: TorchScheduler
+ """
+ return self.schedulers[1]
diff --git a/pina/_src/solver/ensemble_solver.py b/pina/_src/solver/ensemble_solver.py
new file mode 100644
index 000000000..d1a78a870
--- /dev/null
+++ b/pina/_src/solver/ensemble_solver.py
@@ -0,0 +1,82 @@
+"""Module for the ensemble solver class."""
+
+from pina._src.solver.mixin.ensemble_mixin import EnsembleMixin
+from pina._src.solver.base_solver import BaseSolver
+from pina._src.solver.mixin.manual_optimization_mixin import (
+ ManualOptimizationMixin,
+)
+from pina._src.solver.mixin.condition_aggregator_mixin import (
+ ConditionAggregatorMixin,
+)
+
+
+class EnsembleSolver(
+ ManualOptimizationMixin,
+ EnsembleMixin,
+ ConditionAggregatorMixin,
+ BaseSolver,
+):
+ """
+ Base class for implementing ensemble-model solvers.
+
+ This class provides the standard starting point for solvers based on an
+ ensemble of models. It combines the shared solver machinery from
+ :class:`~pina._src.solver.base_solver.BaseSolver` with ensemble-model
+ handling, manual optimization, and condition-wise loss aggregation.
+
+ Subclasses can inherit from this class to implement solver-specific
+ behavior while reusing the common logic for model registration, optimizer
+ and scheduler setup, manual optimization, loss evaluation, weighting, and
+ aggregation across problem conditions.
+ """
+
+ def __init__(
+ self,
+ problem,
+ models,
+ optimizers=None,
+ schedulers=None,
+ weighting=None,
+ loss=None,
+ use_lt=True,
+ ):
+ """
+ Initialization of the :class:`EnsembleSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param models: The model or list of models used by the solver.
+ :type models: torch.nn.Module | list[torch.nn.Module]
+ :param optimizers: The optimizer or list of optimizers used by the
+ solver. If ``None``, the ``torch.optim.Adam`` optimizer with a
+ learning rate of ``0.001`` is used for each model.
+ Default is ``None``.
+ :type optimizers: TorchOptimizer | list[TorchOptimizer]
+ :param schedulers: The scheduler or list of schedulers used by the
+ solver. If ``None``, the ``torch.optim.lr_scheduler.ConstantLR``
+ scheduler with a factor of ``1.0`` is used for each model.
+ Default is ``None``.
+ :type schedulers: TorchScheduler | list[TorchScheduler]
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+ Default is ``True``.
+ """
+
+ # Initialize the base solver
+ BaseSolver.__init__(self, problem=problem, use_lt=use_lt)
+
+ # Initialize the components of the solver
+ self._init_solver_components(
+ models=models,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ )
+
+ # Initialize the weighting scheme for the conditions and the loss
+ self._init_weighting_and_loss(weighting=weighting, loss=loss)
+
+ # Activate manual optimization
+ self._init_manual_optimization()
diff --git a/pina/_src/solver/gradient_physics_informed_single_model_solver.py b/pina/_src/solver/gradient_physics_informed_single_model_solver.py
new file mode 100644
index 000000000..8991420ce
--- /dev/null
+++ b/pina/_src/solver/gradient_physics_informed_single_model_solver.py
@@ -0,0 +1,124 @@
+"""Module for the gradient physics-informed single-model solver class."""
+
+from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
+from pina._src.condition.input_equation_condition import InputEquationCondition
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.solver.single_model_solver import SingleModelSolver
+from pina._src.solver.mixin.gradient_enhanced_mixin import (
+ GradientEnhancedMixin,
+)
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+
+
+class GradientPhysicsInformedSingleModelSolver(
+ PhysicsInformedMixin, GradientEnhancedMixin, SingleModelSolver
+):
+ r"""
+ Single-model solver for gradient-enhanced physics-informed learning
+ problems.
+
+ This solver approximates the solution of a differential problem using a
+ single model and augments the standard physics-informed objective with
+ gradient-enhanced residual terms. It can be used for both forward and
+ inverse problems.
+
+ Given a model :math:`\mathcal{M}`, the predicted solution is
+
+ .. math::
+
+ \hat{\mathbf{u}}(\mathbf{x}) = \mathcal{M}(\mathbf{x}).
+
+ The solver minimizes both the residuals of the differential operators
+ defining the problem and the gradients of those residuals with respect to
+ the input variables. For a problem with governing equation operator
+ :math:`\mathcal{A}` in the domain :math:`\Omega` and boundary operator
+ :math:`\mathcal{B}` on the boundary :math:`\partial\Omega`, the objective
+ can be written as
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{N_{\Omega}}
+ \sum_{i=1}^{N_{\Omega}} \mathcal{L}
+ \left( \mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i) \right)
+ + \frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
+ \mathcal{L} \left( \mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_i) \right)
+ + \frac{1}{N_{\Omega}} \sum_{i=1}^{N_{\Omega}} \mathcal{L}
+ \left( \nabla_{\mathbf{x}} \mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i)
+ \right) + \frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
+ \mathcal{L} \left( \nabla_{\mathbf{x}} \mathcal{B}[\hat{\mathbf{u}}]
+ (\mathbf{x}_i) \right),
+
+ where :math:`\mathcal{L}` is the selected loss function, typically the mean
+ squared error.
+
+ .. seealso::
+
+ **Original reference**: Yu, J., Lu, L., Meng, X., & Karniadakis, G. E.
+ (2022). *Gradient-enhanced physics-informed neural networks for forward
+ and inverse PDE problems.* Computer Methods in Applied Mechanics and
+ Engineering, 393, 114823.
+ DOI: `10.1016/j.cma.2022.114823
+ `_.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (
+ InputTargetCondition,
+ InputEquationCondition,
+ DomainEquationCondition,
+ )
+
+ def __init__(
+ self,
+ problem,
+ model,
+ optimizer=None,
+ scheduler=None,
+ weighting=None,
+ loss=None,
+ regularization_weight=1.0,
+ regularized_conditions=None,
+ ):
+ """
+ Initialization of the :class:`GradientPhysicsInformedSingleModelSolver`
+ class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param torch.nn.Module model: The model used by the solver.
+ :param TorchOptimizer optimizer: The optimizer used by the solver.
+ If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
+ of ``0.001`` is used. Default is ``None``.
+ :param TorchScheduler scheduler: The scheduler used by the solver.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param regularization_weight: The weight of the gradient regularization
+ term. Default is ``1.0``.
+ :type regularization_weight: float | int
+ :param regularized_conditions: The names of the conditions that should
+ receive gradient regularization. If ``None``, all conditions are
+ regularized. Default is ``None``.
+ """
+ # Initialize the parent class
+ SingleModelSolver.__init__(
+ self,
+ problem=problem,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ weighting=weighting,
+ loss=loss,
+ use_lt=True,
+ )
+
+ # Initialize the gradient-enhanced components
+ self._init_gradient_enhanced_components(
+ regularization_weight=regularization_weight,
+ regularized_conditions=regularized_conditions,
+ )
diff --git a/pina/_src/solver/mixin/autoregressive_mixin.py b/pina/_src/solver/mixin/autoregressive_mixin.py
new file mode 100644
index 000000000..33259ca94
--- /dev/null
+++ b/pina/_src/solver/mixin/autoregressive_mixin.py
@@ -0,0 +1,186 @@
+"""Module for the autoregressive mixin class."""
+
+import torch
+from pina._src.core.utils import check_consistency
+
+
+class AutoregressiveMixin:
+ """
+ Mixin that enables the autoregressive rollout loss logic by maintaining a
+ running average of step losses and computing adaptive weights for each step
+ based on the cumulative loss. This allows the solver to focus more on steps
+ that are currently underperforming, which can help improve training
+ stability and convergence.
+
+ Designed to be used in combination with any solver inheriting from
+ :class:`~pina._src.solver.base_solver.BaseSolver`.
+ """
+
+ def _init_autoregressive_components(
+ self, eps, reset_weights_at_epoch_start, kwargs
+ ):
+ """
+ Initialize the components related to the autoregressive rollout loss.
+
+ :param eps: The hyperparameter controlling the influence of the
+ cumulative loss on the adaptive weights. Higher values of eps will
+ lead to more aggressive weighting of steps with higher cumulative
+ loss.
+ :type eps: float | int
+ :param bool reset_weights_at_epoch_start: Whether to reset the running
+ average and step count at the start of each epoch. If ``True``, the
+ adaptive weights will be recalibrated at the beginning of each epoch
+ based on the new training dynamics.
+ :param dict kwargs: Additional keyword arguments for preprocessing and
+ postprocessing steps.
+ :raises ValueError: If ``eps`` is not a float or int.
+ :raises ValueError: If ``reset_weights_at_epoch_start`` is not a bool.
+ """
+ # Check consistency
+ check_consistency(eps, (float, int))
+ check_consistency(reset_weights_at_epoch_start, bool)
+
+ # Initialize the components for autoregressive rollout loss
+ self.reset_weights_at_epoch_start = reset_weights_at_epoch_start
+ self.eps = eps
+ self._running_avg = {}
+ self._step_count = {}
+ self._kwargs = kwargs or {}
+
+ def _loss_from_residual(self, condition_name=None):
+ """
+ Compute the tensor loss from the residual tensor.
+
+ :param str condition_name: The name of the condition.
+ :return: The tensor loss computed from the residual tensor.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Compute the step losses from the residual tensor
+ step_loss = self._loss_fn(
+ self.residual_tensor, torch.zeros_like(self.residual_tensor)
+ )
+
+ # Retrieve the temporal adaptive weights for the current step losses
+ with torch.no_grad():
+ weights = self._get_weights(condition_name or "default", step_loss)
+
+ return step_loss * weights
+
+ def _get_weights(self, condition_name, step_loss):
+ """
+ Get temporal adaptive weights for each step based on the running average
+ of step losses.
+
+ :param str condition_name: The name of the condition.
+ :param step_loss: The tensor of step losses for the current condition.
+ :type step_loss: torch.Tensor | LabelTensor
+ :return: The tensor of adaptive weights for each step.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Use the condition name for tracking the running average and step count
+ key = condition_name or "default"
+ reduce_dims = tuple(range(1, step_loss.dim()))
+ step_loss = step_loss.detach().mean(dim=reduce_dims, keepdim=True)
+
+ # Update the running average and step count for the current condition
+ if key not in self._running_avg:
+ self._running_avg[key] = step_loss.detach().clone()
+ self._step_count[key] = 1
+ else:
+ self._step_count[key] += 1
+ value = step_loss.detach() - self._running_avg[key]
+ self._running_avg[key] += value / self._step_count[key]
+
+ return self._compute_adaptive_weights(self._running_avg[key])
+
+ def _compute_adaptive_weights(self, step_loss):
+ """
+ Compute the adaptive weights for each step based on the cumulative loss.
+
+ :param step_loss: The tensor of step losses for the current condition.
+ :type step_loss: torch.Tensor | LabelTensor
+ :return: The tensor of adaptive weights for each step.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ cumulative_loss = -self.eps * torch.cumsum(step_loss, dim=0)
+ return torch.exp(cumulative_loss)
+
+ def on_train_epoch_start(self):
+ """
+ Clear the running average and step count at the start of each epoch if
+ ``reset_weights_at_epoch_start`` is ``True``.
+ """
+ if self.reset_weights_at_epoch_start:
+ self._running_avg.clear()
+ self._step_count.clear()
+
+ def preprocess_step(self, current_state, **kwargs):
+ """
+ Preprocess the current state before each step.
+
+ :param current_state: The current state tensor.
+ :type current_state: torch.Tensor | LabelTensor
+ :param dict kwargs: Additional keyword arguments for preprocessing.
+ :return: The preprocessed state tensor.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ return current_state
+
+ def postprocess_step(self, predicted_state, **kwargs):
+ """
+ Postprocess the predicted state after each step. If multiple models are
+ used, average the predictions across the model dimension.
+
+ :param predicted_state: The predicted state tensor.
+ :type predicted_state: torch.Tensor | LabelTensor
+ :param dict kwargs: Additional keyword arguments for postprocessing.
+ :return: The postprocessed state tensor.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ return predicted_state
+
+ def predict(self, initial_state, n_steps, **kwargs):
+ """
+ Generate predictions by recursively calling the model's forward.
+
+ :param initial_state: The initial state from which to start prediction.
+ The initial state must be of shape ``[trajectories, 1, *features]``.
+ :type initial_state: torch.Tensor | LabelTensor
+ :param int n_steps: The number of autoregressive steps to predict.
+ :param dict kwargs: Additional keyword arguments.
+ :raises ValueError: If the provided initial_state tensor has less than 3
+ dimensions.
+ :return: The predicted trajectory, including the initial state. It has
+ shape ``[trajectories, n_steps + 1, *features]``, where the first
+ step corresponds to the initial state.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Set model to evaluation mode for prediction
+ self.eval()
+
+ # Raise error if the initial_state does not have at least 3 dimensions
+ if initial_state.dim() < 3:
+ raise ValueError(
+ "The provided initial_state tensor must have at least 3"
+ "dimensions: [trajectories, 1, *features]."
+ f" Got shape {initial_state.shape}."
+ )
+
+ # Initialize the list of predictions with the initial state
+ predictions = [initial_state]
+
+ # Disable gradient computation for autoregressive prediction
+ with torch.no_grad():
+
+ # Unroll the autoregressive prediction for n_steps
+ for _ in range(n_steps):
+
+ # Preprocess the current state before the forward pass
+ current_state = self.preprocess_step(predictions[-1], **kwargs)
+ output = self.forward(current_state)
+
+ # Postprocess the predicted state after the forward pass
+ next_state = self.postprocess_step(output, **kwargs)
+ predictions.append(next_state)
+
+ return torch.cat(predictions, dim=1)
diff --git a/pina/_src/solver/mixin/condition_aggregator_mixin.py b/pina/_src/solver/mixin/condition_aggregator_mixin.py
new file mode 100644
index 000000000..ad5b023ac
--- /dev/null
+++ b/pina/_src/solver/mixin/condition_aggregator_mixin.py
@@ -0,0 +1,58 @@
+"""Module for the condition aggregator mixin class."""
+
+import torch
+
+
+class ConditionAggregatorMixin:
+ """
+ Mixin that logs per-condition scalar losses, weights them following the
+ provided weighting scheme, and aggregates them into the total loss.
+
+ Designed to be used in combination with any solver inheriting from
+ :class:`~pina._src.solver.base_solver.BaseSolver`.
+ """
+
+ def batch_evaluation_step(self, batch, batch_idx):
+ """
+ Evaluate and aggregate the losses for all conditions in a batch.
+
+ For each condition in the batch, this method computes the corresponding
+ scalar loss, logs it using the condition name, and combines all
+ condition losses into a single scalar loss through the configured
+ weighting scheme.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The aggregated scalar loss for the batch.
+ :rtype: torch.Tensor
+ """
+ # Initialize a dictionary to hold the scalar losses for each condition
+ condition_losses = {}
+
+ # Loop through each condition in the batch and compute its scalar loss
+ for condition_name, data in batch:
+
+ # Compute the scalar loss for the current condition
+ condition_losses[condition_name] = self._compute_condition_loss(
+ condition=self.problem.conditions[condition_name],
+ data=dict(data),
+ batch_idx=batch_idx,
+ )
+
+ # Clamp parameters - null operation if problem is not InverseProblem
+ self._clamp_params()
+
+ # Log the individual condition losses
+ for name, value in condition_losses.items():
+ self.log(
+ name=f"{name}_loss",
+ value=value.item(),
+ batch_size=self.get_batch_size(batch),
+ **self.trainer.logging_kwargs,
+ )
+
+ # Aggregate into the total loss using the weighting scheme
+ aggregated_loss = self.weighting.aggregate(condition_losses)
+
+ return aggregated_loss.as_subclass(torch.Tensor)
diff --git a/pina/_src/solver/mixin/ensemble_mixin.py b/pina/_src/solver/mixin/ensemble_mixin.py
new file mode 100644
index 000000000..17757fc96
--- /dev/null
+++ b/pina/_src/solver/mixin/ensemble_mixin.py
@@ -0,0 +1,81 @@
+"""Module for the ensemble mixin class."""
+
+import torch
+from pina._src.solver.base_solver import BaseSolver
+from pina._src.solver.mixin.multi_model_mixin import MultiModelMixin
+
+
+class EnsembleMixin(MultiModelMixin):
+ """
+ Mixin that defines the forward pass and optimizer configuration for solvers
+ backed by an ensemble of models. Provides properties to access the models,
+ optimizers, and schedulers.
+
+ Designed to be used in combination with any solver inheriting from
+ :class:`~pina._src.solver.base_solver.BaseSolver`.
+ """
+
+ def forward(self, x):
+ """
+ Forward pass for ensemble solvers. If an active model index is set, only
+ that model is evaluated. Otherwise, all models are evaluated and their
+ outputs are stacked together.
+
+ :param x: The input data.
+ :type x: torch.Tensor | LabelTensor | Data | Graph
+ :return: The output of all models stacked together.
+ :rtype: torch.Tensor | LabelTensor | Data | Graph
+ """
+ # Retrieve the index of the active model if set
+ active_idx = getattr(self, "_active_model_idx", None)
+
+ # If an active model index is set, evaluate only that model
+ if active_idx is not None:
+ return self.models[active_idx](x)
+
+ # Otherwise, evaluate all models and stack outputs
+ return torch.stack(
+ [self.models[idx](x) for idx in range(self.num_models)]
+ )
+
+ def _compute_condition_loss(self, condition, data, batch_idx):
+ """
+ Compute the scalar loss for a given condition and its data.
+
+ :param BaseCondition condition: The condition for which to compute the
+ loss.
+ :param dict data: The data corresponding to the condition.
+ :param int batch_idx: The index of the current batch.
+ :return: The scalar loss for the condition.
+ :rtype: torch.Tensor
+ """
+ # Initialize model losses for the current condition
+ model_losses = []
+
+ # Restore the active model index if it was set, else set it to None
+ previous_active_model_idx = getattr(self, "_active_model_idx", None)
+
+ # Try - finally to ensure active model index is always restored
+ try:
+
+ # Iterate over all ensemble models to compute individual losses
+ for model_idx in range(self.num_models):
+
+ # Set the active model index for the current iteration
+ self._active_model_idx = model_idx
+
+ # Compute the scalar loss for the current model and condition
+ condition_scalar_loss = BaseSolver._compute_condition_loss(
+ self, condition, data, batch_idx
+ )
+
+ # Store the computed loss for the current model
+ model_losses.append(condition_scalar_loss)
+
+ # Ensure that the active model index is always restored
+ finally:
+
+ # Restore the previous active model index after computation
+ self._active_model_idx = previous_active_model_idx
+
+ return torch.stack(model_losses).mean()
diff --git a/pina/_src/solver/mixin/gradient_enhanced_mixin.py b/pina/_src/solver/mixin/gradient_enhanced_mixin.py
new file mode 100644
index 000000000..8a492dc1e
--- /dev/null
+++ b/pina/_src/solver/mixin/gradient_enhanced_mixin.py
@@ -0,0 +1,148 @@
+"""Module for the gradient-enhanced mixin class."""
+
+import torch
+from pina._src.problem.spatial_problem import SpatialProblem
+from pina._src.core.utils import check_consistency
+from pina._src.core.operator import grad
+
+
+class GradientEnhancedMixin:
+ """
+ Mixin that augments residual losses with a gradient-based regularization
+ term.
+
+ The additional penalty is the norm of the residual gradient with respect
+ to the spatial input variables. It is only applied to the conditions whose
+ names are listed in ``regularized_conditions``.
+
+ Designed to be used in combination with any solver inheriting from
+ :class:`~pina._src.solver.base_solver.BaseSolver` and using
+ :class:`~pina._src.core.tensor.label_tensor.LabelTensor` as input.
+ """
+
+ def _init_gradient_enhanced_components(
+ self, regularization_weight=1.0, regularized_conditions=None
+ ):
+ """
+ Initialize the gradient-enhancement parameters.
+
+ :param regularization_weight: The weight of the gradient regularization
+ term. Default is ``1.0``.
+ :type regularization_weight: float | int
+ :param regularized_conditions: The names of the conditions that should
+ receive gradient regularization. If ``None``, all conditions are
+ regularized. Default is ``None``.
+ :type regularized_conditions: str | list[str]
+ :raises ValueError: If ``regularization_weight`` is not a float or int.
+ :raises ValueError: If ``regularized_conditions`` is not a string or a
+ list of strings.
+ :raises ValueError: If ``problem`` is not an instance of
+ :class:`~pina._src.problem.spatial_problem.SpatialProblem`.
+ :raises ValueError: If the solver's input data are not instances of
+ :class:`~pina._src.core.tensor.label_tensor.LabelTensor`.
+ :raises ValueError: If any of the specified ``regularized_conditions``
+ are not present in the ``problem``'s conditions.
+ """
+ # Use all conditions if regularized_conditions is None
+ if regularized_conditions is None:
+ regularized_conditions = list(self.problem.conditions.keys())
+
+ # Check consistency
+ check_consistency(regularization_weight, (float, int))
+ check_consistency(regularized_conditions, str)
+
+ # Map conditions to list if a single condition is provided
+ if not isinstance(regularized_conditions, (list, tuple)):
+ regularized_conditions = [regularized_conditions]
+
+ # Assert the problem is instance of SpatialProblem
+ if not isinstance(self.problem, SpatialProblem):
+ raise ValueError(
+ "Gradient-enhanced regularization requires the problem to be "
+ f"an instance of SpatialProblem. Got {type(self.problem)}."
+ )
+
+ # Ensure that the solver is using LabelTensors as input
+ if not self.use_lt:
+ raise ValueError(
+ "Gradient-enhanced regularization requires the solver to use "
+ f"LabelTensors as input. Got use_lt={self.use_lt}."
+ )
+
+ # Ensure that all regularized conditions are present in the problem
+ problem_conditions = set(self.problem.conditions.keys())
+ for condition in regularized_conditions:
+ if condition not in problem_conditions:
+ raise ValueError(
+ f"Condition '{condition}' is not present in the problem."
+ )
+
+ # Initialize the gradient-enhancement parameters
+ self.regularization_weight = regularization_weight
+ self.regularized_conditions = regularized_conditions
+
+ def _prepare_condition_data(self, data):
+ """
+ Prepare the condition data for loss computation. This method can be
+ overridden by mixins to implement specific data preparation steps, such
+ as enabling gradient tracking for inputs in gradient-enhanced solvers.
+
+ :param dict data: The original condition data.
+ :return: The prepared condition data.
+ :rtype: dict
+ """
+ # If data does not require grad, force requires_grad to True
+ if "input" in data and not data["input"].requires_grad:
+ data["input"].requires_grad_(True)
+
+ return data
+
+ def _regularize_condition_loss(
+ self,
+ condition_tensor_loss,
+ condition_name,
+ data,
+ batch_idx,
+ ):
+ """
+ Regularize the condition loss if needed. This method can be overridden
+ by mixins to implement specific regularization strategies, such as
+ adding a gradient penalty in gradient-enhanced solvers or applying
+ residual-based attention.
+
+ :param condition_tensor_loss: The original tensor loss for the
+ condition.
+ :type condition_tensor_loss: torch.Tensor | LabelTensor
+ :param str condition_name: The name of the condition.
+ :param dict data: The data corresponding to the condition.
+ :param int batch_idx: The index of the current batch.
+ :return: The regularized tensor loss for the condition.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Regularize the loss with the gradient penalty if needed
+ if condition_name in self.regularized_conditions:
+
+ # Apply labels to the residual tensor for gradient computation
+ self.residual_tensor.labels = [
+ f"res_{i}" for i in range(self.residual_tensor.shape[1])
+ ]
+
+ # Compute the gradient of the residual with respect to spatial input
+ residual_gradient = grad(
+ output_=self.residual_tensor,
+ input_=data["input"],
+ d=self.problem.spatial_variables,
+ )
+
+ # Compute the norm of the residual gradient
+ residual_gradient_norm = self._loss_fn(
+ residual_gradient, torch.zeros_like(residual_gradient)
+ )
+
+ # Compute the gradient penalty term
+ penalty = self.regularization_weight * residual_gradient_norm
+
+ # Add the gradient penalty to the original condition tensor loss
+ condition_tensor_loss = condition_tensor_loss + penalty
+
+ return condition_tensor_loss
diff --git a/pina/_src/solver/mixin/manual_optimization_mixin.py b/pina/_src/solver/mixin/manual_optimization_mixin.py
new file mode 100644
index 000000000..bef6380a2
--- /dev/null
+++ b/pina/_src/solver/mixin/manual_optimization_mixin.py
@@ -0,0 +1,66 @@
+"""Module for the manual optimization mixin class."""
+
+
+class ManualOptimizationMixin:
+ """
+ Mixin that handles Lightning manual optimization loops, useful for solvers
+ that require explicit control over optimization steps, such as those with
+ multiple optimizers or custom training loops.
+
+ Designed to be used in combination with any solver inheriting from
+ :class:`~pina._src.solver.base_solver.BaseSolver`.
+ """
+
+ def _init_manual_optimization(self):
+ """
+ Disable Lightning's automatic optimization.
+ """
+ self.automatic_optimization = False
+
+ def training_step(self, batch, batch_idx):
+ """
+ Solver training step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+ # Zero the gradients of all optimizers
+ for opt in self.optimizers:
+ opt.instance.zero_grad()
+
+ # Perform the forward pass and compute the loss
+ loss = super().training_step(batch, batch_idx)
+
+ # Perform the backward pass
+ self.manual_backward(loss)
+
+ # Step the optimizers and schedulers
+ for opt, sched in zip(self.optimizers, self.schedulers):
+ opt.instance.step()
+ sched.instance.step()
+
+ return loss
+
+ def on_train_batch_end(self, outputs, batch, batch_idx):
+ """
+ Keep Lightning's manual optimization progress counters in sync.
+
+ This hook increments the completed optimization-step counter used by
+ Lightning's manual optimization loop, then delegates to the parent
+ implementation.
+
+ :param torch.Tensor outputs: The loss of the training step.
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The result returned by the parent class implementation.
+ :rtype: Any
+ """
+ # Sync the manual optimization progress counters in Lightning's loop
+ epoch_loop = self.trainer.fit_loop.epoch_loop
+ epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
+
+ return super().on_train_batch_end(outputs, batch, batch_idx)
diff --git a/pina/_src/solver/mixin/multi_model_mixin.py b/pina/_src/solver/mixin/multi_model_mixin.py
new file mode 100644
index 000000000..723020fbb
--- /dev/null
+++ b/pina/_src/solver/mixin/multi_model_mixin.py
@@ -0,0 +1,103 @@
+"""Module for the multi-model mixin class."""
+
+import torch
+from pina._src.problem.inverse_problem import InverseProblem
+
+
+class MultiModelMixin:
+ """
+ Mixin that defines the forward pass and optimizer configuration for solvers
+ backed by multiple models. Provides properties to access the models,
+ optimizers, and schedulers.
+
+ Designed to be used in combination with any solver inheriting from
+ :class:`~pina._src.solver.base_solver.BaseSolver`.
+ """
+
+ def forward(self, x):
+ """
+ The forward pass implementation that evaluates all models and returns a
+ stacked tensor of their outputs.
+
+ :param x: The input data.
+ :type x: torch.Tensor | LabelTensor | Data | Graph
+ :return: The output of all models stacked together.
+ :rtype: torch.Tensor | LabelTensor | Data | Graph
+ """
+ return torch.stack(
+ [self.models[idx](x) for idx in range(self.num_models)]
+ )
+
+ def configure_optimizers(self):
+ """
+ Configure the optimizers and schedulers for all models.
+
+ :return: The optimizer and the scheduler
+ :rtype: tuple[list[TorchOptimizer], list[TorchScheduler]]
+ """
+ # Iterate over models, optimizers, and schedulers to hook them together
+ for optimizer, scheduler, model in zip(
+ self.optimizers, self.schedulers, self.models
+ ):
+
+ # Hook the optimizer to the model parameters
+ optimizer.hook(model.parameters())
+
+ # Add parameter group for inverse problems if needed
+ if isinstance(self.problem, InverseProblem):
+ optimizer.instance.add_param_group(
+ {
+ "params": [
+ self._params[var]
+ for var in self.problem.unknown_variables
+ ]
+ }
+ )
+
+ # Hook the scheduler to the optimizer
+ scheduler.hook(optimizer)
+
+ return (
+ [optimizer.instance for optimizer in self.optimizers],
+ [scheduler.instance for scheduler in self.schedulers],
+ )
+
+ @property
+ def models(self):
+ """
+ The models used by the solver.
+
+ :return: The models used by the solver.
+ :rtype: list[torch.nn.Module]
+ """
+ return self._pina_models
+
+ @property
+ def optimizers(self):
+ """
+ The optimizers used by the solver.
+
+ :return: The optimizers used by the solver.
+ :rtype: list[TorchOptimizer]
+ """
+ return self._pina_optimizers
+
+ @property
+ def schedulers(self):
+ """
+ The schedulers used by the solver.
+
+ :return: The schedulers used by the solver.
+ :rtype: list[TorchScheduler]
+ """
+ return self._pina_schedulers
+
+ @property
+ def num_models(self):
+ """
+ The number of models used by the solver.
+
+ :return: The number of models used by the solver.
+ :rtype: int
+ """
+ return len(self.models)
diff --git a/pina/_src/solver/mixin/physics_informed_mixin.py b/pina/_src/solver/mixin/physics_informed_mixin.py
new file mode 100644
index 000000000..04229ff65
--- /dev/null
+++ b/pina/_src/solver/mixin/physics_informed_mixin.py
@@ -0,0 +1,40 @@
+"""Module for the physics-informed mixin class."""
+
+import torch
+
+
+class PhysicsInformedMixin:
+ """
+ Mixin that enables physics-informed training by ensuring gradients are
+ enabled during validation and testing, which is necessary for computing
+ physics residuals.
+
+ Designed to be used in combination with any solver inheriting from
+ :class:`~pina._src.solver.base_solver.BaseSolver`.
+ """
+
+ @torch.enable_grad()
+ def validation_step(self, batch, batch_idx):
+ """
+ Solver validation step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+ return super().validation_step(batch, batch_idx)
+
+ @torch.enable_grad()
+ def test_step(self, batch, batch_idx):
+ """
+ Solver test step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+ return super().test_step(batch, batch_idx)
diff --git a/pina/_src/solver/mixin/residual_based_attention_mixin.py b/pina/_src/solver/mixin/residual_based_attention_mixin.py
new file mode 100644
index 000000000..04b72fa4e
--- /dev/null
+++ b/pina/_src/solver/mixin/residual_based_attention_mixin.py
@@ -0,0 +1,150 @@
+"""Module for the residual-based attention mixin class."""
+
+import torch
+from pina._src.core.utils import check_consistency
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+
+
+class ResidualBasedAttentionMixin:
+ """
+ Mixin that augments the residual loss with an attention mechanism based on
+ the residual values.
+
+ The attention weights are computed as a function of the residuals, and they
+ are used to weight the contribution of each condition to the overall loss.
+ This allows the solver to focus more on conditions with larger residuals,
+ potentially improving convergence and accuracy.
+
+ Designed to be used in combination with any solver inheriting from
+ :class:`~pina._src.solver.base_solver.BaseSolver`.
+ """
+
+ def _init_residual_attention_components(
+ self, eta=0.001, gamma=0.999, regularized_conditions=None
+ ):
+ """
+ Initialize the residual-based attention parameters.
+
+ :param eta: The learning rate for the residual-based attention weights
+ update. Default is ``0.001``.
+ :type eta: float | int
+ :param float gamma: The decay factor for the residual-based attention
+ mechanism. Default is ``0.999``.
+ :param regularized_conditions: The names of the conditions that should
+ receive attention regularization. If ``None``, all conditions are
+ regularized. Default is ``None``.
+ :type regularized_conditions: str | list[str]
+ :raises ValueError: If ``eta`` is not a positive float or int.
+ :raises ValueError: If ``gamma`` is not a float in the range (0, 1).
+ :raises ValueError: If ``regularized_conditions`` is not a string or a
+ list of strings.
+ :raises ValueError: If any of the specified ``regularized_conditions``
+ are not present in the ``problem``'s conditions.
+ """
+ # Use all conditions if regularized_conditions is None
+ if regularized_conditions is None:
+ regularized_conditions = list(self.problem.conditions.keys())
+
+ # Check consistency
+ check_consistency(eta, (float, int))
+ check_consistency(gamma, float)
+ check_consistency(regularized_conditions, str)
+
+ # Assert gamma is in range (0, 1)
+ if not 0 < gamma < 1:
+ raise ValueError("gamma must be in range (0, 1)")
+
+ # Assert eta is positive
+ if eta <= 0:
+ raise ValueError("eta must be positive")
+
+ # Map conditions to list if a single condition is provided
+ if not isinstance(regularized_conditions, (list, tuple)):
+ regularized_conditions = [regularized_conditions]
+
+ # Ensure that all regularized conditions are present in the problem
+ problem_conditions = set(self.problem.conditions.keys())
+ for condition in regularized_conditions:
+ if condition not in problem_conditions:
+ raise ValueError(
+ f"Condition '{condition}' is not present in the problem."
+ )
+
+ # Initialize residual-based attention parameters
+ self.regularized_conditions = regularized_conditions
+ self.gamma = gamma
+ self.eta = eta
+ self.weight_buffers = {}
+
+ # Iterate over all conditions to initialize the attention weights
+ for cond in self.regularized_conditions:
+
+ # Get the condition object
+ condition = self.problem.conditions[cond]
+
+ # Determine the number of points in the condition
+ if isinstance(condition, DomainEquationCondition):
+ n_pts = self.problem._discretised_domains[cond].shape[0]
+ else:
+ n_pts = condition.data.input.shape[0]
+
+ # Register the attention weights as a buffer in the module
+ self.register_buffer(f"weight_{cond}", torch.zeros((n_pts, 1)))
+ self.weight_buffers[cond] = f"weight_{cond}"
+
+ def _regularize_condition_loss(
+ self,
+ condition_tensor_loss,
+ condition_name,
+ data,
+ batch_idx,
+ ):
+ """
+ Regularize the condition loss if needed. This method can be overridden
+ by mixins to implement specific regularization strategies, such as
+ adding a gradient penalty in gradient-enhanced solvers or applying
+ residual-based attention.
+
+ :param condition_tensor_loss: The original tensor loss for the
+ condition.
+ :type condition_tensor_loss: torch.Tensor | LabelTensor
+ :param str condition_name: The name of the condition.
+ :param dict data: The data corresponding to the condition.
+ :param int batch_idx: The index of the current batch.
+ :return: The regularized tensor loss for the condition.
+ :rtype: torch.Tensor | LabelTensor
+ """
+ # Apply residual-based attention mechanism if needed
+ if condition_name in self.regularized_conditions:
+
+ # Compute the normalized residuals norm for the current condition
+ res_abs = torch.linalg.vector_norm(
+ self.residual_tensor, ord=2, dim=1, keepdim=True
+ )
+ res_norm = res_abs / (res_abs.max() + 1e-12)
+
+ # Get the correct indices to retrieve the weights for the batch
+ len_residuals = self.residual_tensor.shape[0]
+
+ # Get the weights buffer for the current condition
+ weights = getattr(self, self.weight_buffers[condition_name])
+
+ # Get the total number of points, together with the start / end idx
+ total_points = weights.shape[0]
+ start = (batch_idx * len_residuals) % total_points
+ end = start + len_residuals
+
+ # Retrieve the weights for the current batch using modular indexing
+ idx = torch.arange(start, end, device=weights.device)
+ idx = idx % total_points
+
+ # Update weights
+ with torch.no_grad():
+ weights[idx] = self.gamma * weights[idx] + self.eta * res_norm
+
+ # Weight the condition tensor loss with attention weights
+ condition_tensor_loss = condition_tensor_loss * weights[idx]
+
+ return condition_tensor_loss
diff --git a/pina/_src/solver/mixin/single_model_mixin.py b/pina/_src/solver/mixin/single_model_mixin.py
new file mode 100644
index 000000000..74af1ab1a
--- /dev/null
+++ b/pina/_src/solver/mixin/single_model_mixin.py
@@ -0,0 +1,82 @@
+"""Module for the single-model mixin class."""
+
+from pina._src.problem.inverse_problem import InverseProblem
+
+
+class SingleModelMixin:
+ """
+ Mixin that defines the forward pass and optimizer configuration for solvers
+ backed by exactly one model. Provides properties to access the single model,
+ optimizer, and scheduler.
+
+ Designed to be used in combination with any solver inheriting from
+ :class:`~pina._src.solver.base_solver.BaseSolver`.
+ """
+
+ def forward(self, x):
+ """
+ The forward pass implementation for the single model, which simply
+ evaluates the model on the input.
+
+ :param x: The input data.
+ :type x: torch.Tensor | LabelTensor | Data | Graph
+ :return: The output of the single model.
+ :rtype: torch.Tensor | LabelTensor | Data | Graph
+ """
+ return self.model(x)
+
+ def configure_optimizers(self):
+ """
+ Configure the optimizer and scheduler for the single model.
+
+ :return: The optimizer and the scheduler
+ :rtype: tuple[list[TorchOptimizer], list[TorchScheduler]]
+ """
+ # Hook the optimizer to the model parameters
+ self.optimizer.hook(self.model.parameters())
+
+ # Add parameter group for inverse problems if needed
+ if isinstance(self.problem, InverseProblem):
+ self.optimizer.instance.add_param_group(
+ {
+ "params": [
+ self._params[var]
+ for var in self.problem.unknown_variables
+ ]
+ }
+ )
+
+ # Hook the scheduler to the optimizer
+ self.scheduler.hook(self.optimizer)
+
+ return ([self.optimizer.instance], [self.scheduler.instance])
+
+ @property
+ def model(self):
+ """
+ The single model used by the solver.
+
+ :return: The single model used by the solver.
+ :rtype: torch.nn.Module
+ """
+ return self._pina_models[0]
+
+ @property
+ def optimizer(self):
+ """
+ The optimizer used by the solver.
+
+ :return: The optimizer used by the solver.
+ :rtype: TorchOptimizer
+ """
+ return self._pina_optimizers[0]
+
+ @property
+ def scheduler(self):
+ """
+ The scheduler used by the solver.
+
+ :return: The scheduler used by the solver.
+ :rtype: TorchScheduler
+ """
+ return self._pina_schedulers[0]
diff --git a/pina/_src/solver/multi_model_solver.py b/pina/_src/solver/multi_model_solver.py
new file mode 100644
index 000000000..3fdec7d9c
--- /dev/null
+++ b/pina/_src/solver/multi_model_solver.py
@@ -0,0 +1,82 @@
+"""Module for the multi-model solver class."""
+
+from pina._src.solver.mixin.multi_model_mixin import MultiModelMixin
+from pina._src.solver.base_solver import BaseSolver
+from pina._src.solver.mixin.manual_optimization_mixin import (
+ ManualOptimizationMixin,
+)
+from pina._src.solver.mixin.condition_aggregator_mixin import (
+ ConditionAggregatorMixin,
+)
+
+
+class MultiModelSolver(
+ ManualOptimizationMixin,
+ MultiModelMixin,
+ ConditionAggregatorMixin,
+ BaseSolver,
+):
+ """
+ Base class for implementing multi-model solvers.
+
+ This class provides the standard starting point for solvers based on
+ multiple models. It combines the shared solver machinery from
+ :class:`~pina._src.solver.base_solver.BaseSolver` with multi-model handling,
+ manual optimization, and condition-wise loss aggregation.
+
+ Subclasses can inherit from this class to implement solver-specific behavior
+ while reusing the common logic for model registration, optimizer and
+ scheduler setup, manual optimization, loss evaluation, weighting, and
+ aggregation across problem conditions.
+ """
+
+ def __init__(
+ self,
+ problem,
+ models,
+ optimizers=None,
+ schedulers=None,
+ weighting=None,
+ loss=None,
+ use_lt=True,
+ ):
+ """
+ Initialization of the :class:`MultiModelSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param models: The model or list of models used by the solver.
+ :type models: torch.nn.Module | list[torch.nn.Module]
+ :param optimizers: The optimizer or list of optimizers used by the
+ solver. If ``None``, the ``torch.optim.Adam`` optimizer with a
+ learning rate of ``0.001`` is used for each model.
+ Default is ``None``.
+ :type optimizers: TorchOptimizer | list[TorchOptimizer]
+ :param schedulers: The scheduler or list of schedulers used by the
+ solver. If ``None``, the ``torch.optim.lr_scheduler.ConstantLR``
+ scheduler with a factor of ``1.0`` is used for each model.
+ Default is ``None``.
+ :type schedulers: TorchScheduler | list[TorchScheduler]
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+ Default is ``True``.
+ """
+
+ # Initialize the base solver
+ BaseSolver.__init__(self, problem=problem, use_lt=use_lt)
+
+ # Initialize the components of the solver
+ self._init_solver_components(
+ models=models,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ )
+
+ # Initialize the weighting scheme for the conditions and the loss
+ self._init_weighting_and_loss(weighting=weighting, loss=loss)
+
+ # Activate manual optimization
+ self._init_manual_optimization()
diff --git a/pina/_src/solver/physics_informed_ensemble_solver.py b/pina/_src/solver/physics_informed_ensemble_solver.py
new file mode 100644
index 000000000..15cc7e1e6
--- /dev/null
+++ b/pina/_src/solver/physics_informed_ensemble_solver.py
@@ -0,0 +1,96 @@
+"""Module for the physics-informed ensemble solver class."""
+
+from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
+from pina._src.condition.input_equation_condition import InputEquationCondition
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.solver.ensemble_solver import EnsembleSolver
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+
+
+class PhysicsInformedEnsembleSolver(PhysicsInformedMixin, EnsembleSolver):
+ r"""
+ Ensemble-model solver for physics-informed learning problems.
+
+ This solver approximates the solution of a differential problem using an
+ ensemble of models. It is intended for problems whose conditions may include
+ supervised data, equation residuals evaluated on input points, and equation
+ residuals sampled from domains.
+
+ Given an ensemble of models :math:`\{\mathcal{M}_j\}_{j=1}^{M}`, the
+ predicted solution of each model is
+
+ .. math::
+
+ \hat{\mathbf{u}}^{(j)}(\mathbf{x}) = \mathcal{M}_j(\mathbf{x}),
+ \qquad j = 1, \ldots, M.
+
+ The solver minimizes the residuals of the differential operators defining
+ the problem for each model in the ensemble. For a problem with governing
+ equation operator :math:`\mathcal{A}` in the domain :math:`\Omega` and
+ boundary operator :math:`\mathcal{B}` on the boundary
+ :math:`\partial\Omega`, the objective can be written as
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{M} \sum_{j=1}^{M}
+ \left[ \frac{1}{N_{\Omega}} \sum_{i=1}^{N_{\Omega}} \mathcal{L}
+ \left( \mathcal{A}[\hat{\mathbf{u}}^{(j)}](\mathbf{x}_i) \right)
+ + \frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
+ \mathcal{L}
+ \left( \mathcal{B}[\hat{\mathbf{u}}^{(j)}](\mathbf{x}_i) \right)
+ \right],
+
+ where :math:`\mathcal{L}` is the selected loss function, typically the
+ mean squared error.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (
+ InputTargetCondition,
+ InputEquationCondition,
+ DomainEquationCondition,
+ )
+
+ def __init__(
+ self,
+ problem,
+ models,
+ optimizers=None,
+ schedulers=None,
+ weighting=None,
+ loss=None,
+ ):
+ """
+ Initialization of the :class:`PhysicsInformedEnsembleSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param models: The model or list of models used by the solver.
+ :type models: torch.nn.Module | list[torch.nn.Module]
+ :param optimizers: The optimizer or list of optimizers used by the
+ solver. If ``None``, the ``torch.optim.Adam`` optimizer with a
+ learning rate of ``0.001`` is used for each model.
+ Default is ``None``.
+ :type optimizers: TorchOptimizer | list[TorchOptimizer]
+ :param schedulers: The scheduler or list of schedulers used by the
+ solver. If ``None``, the ``torch.optim.lr_scheduler.ConstantLR``
+ scheduler with a factor of ``1.0`` is used for each model.
+ Default is ``None``.
+ :type schedulers: TorchScheduler | list[TorchScheduler]
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ """
+ EnsembleSolver.__init__(
+ self,
+ problem=problem,
+ models=models,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ weighting=weighting,
+ loss=loss,
+ use_lt=True,
+ )
diff --git a/pina/_src/solver/physics_informed_single_model_solver.py b/pina/_src/solver/physics_informed_single_model_solver.py
new file mode 100644
index 000000000..1a5f783a2
--- /dev/null
+++ b/pina/_src/solver/physics_informed_single_model_solver.py
@@ -0,0 +1,96 @@
+"""Module for the physics-informed single-model solver class."""
+
+from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
+from pina._src.condition.input_equation_condition import InputEquationCondition
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.solver.single_model_solver import SingleModelSolver
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+
+
+class PhysicsInformedSingleModelSolver(PhysicsInformedMixin, SingleModelSolver):
+ r"""
+ Single-model solver for physics-informed learning problems.
+
+ This solver approximates the solution of a differential problem using a
+ single model. It is intended for problems whose conditions may include
+ supervised data, equation residuals evaluated on input points, and equation
+ residuals sampled from domains.
+
+ Given a model :math:`\mathcal{M}`, the predicted solution is
+
+ .. math::
+
+ \hat{\mathbf{u}}(\mathbf{x}) = \mathcal{M}(\mathbf{x}).
+
+ The solver minimizes the residuals of the differential operators defining
+ the problem. For a problem with governing equation operator
+ :math:`\mathcal{A}` in the domain :math:`\Omega` and boundary operator
+ :math:`\mathcal{B}` on the boundary :math:`\partial\Omega`, the objective
+ can be written as
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{N_{\Omega}}
+ \sum_{i=1}^{N_{\Omega}} \mathcal{L}
+ \left( \mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i) \right)
+ + \frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
+ \mathcal{L} \left( \mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_i) \right),
+
+ where :math:`\mathcal{L}` is the selected loss function, typically the
+ mean squared error.
+
+ .. seealso::
+
+ **Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L.,
+ Perdikaris, P., Wang, S., & Yang, L. (2021).
+ *Physics-informed machine learning.*
+ Nature Reviews Physics, 3, 422-440.
+ DOI: `10.1038/s42254-021-00314-5
+ `_.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (
+ InputTargetCondition,
+ InputEquationCondition,
+ DomainEquationCondition,
+ )
+
+ def __init__(
+ self,
+ problem,
+ model,
+ optimizer=None,
+ scheduler=None,
+ weighting=None,
+ loss=None,
+ ):
+ """
+ Initialization of the :class:`PhysicsInformedSingleModelSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param torch.nn.Module model: The model used by the solver.
+ :param TorchOptimizer optimizer: The optimizer used by the solver.
+ If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
+ of ``0.001`` is used. Default is ``None``.
+ :param TorchScheduler scheduler: The scheduler used by the solver.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ """
+ SingleModelSolver.__init__(
+ self,
+ problem=problem,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ weighting=weighting,
+ loss=loss,
+ use_lt=True,
+ )
diff --git a/pina/_src/solver/rba_physics_informed_single_model_solver.py b/pina/_src/solver/rba_physics_informed_single_model_solver.py
new file mode 100644
index 000000000..d19a8f229
--- /dev/null
+++ b/pina/_src/solver/rba_physics_informed_single_model_solver.py
@@ -0,0 +1,140 @@
+"""
+Module for the residual-based attention physics-informed single-model solver
+class.
+"""
+
+from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
+from pina._src.condition.input_equation_condition import InputEquationCondition
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.solver.single_model_solver import SingleModelSolver
+from pina._src.solver.mixin.residual_based_attention_mixin import (
+ ResidualBasedAttentionMixin,
+)
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+
+
+class RBAPhysicsInformedSingleModelSolver(
+ PhysicsInformedMixin, ResidualBasedAttentionMixin, SingleModelSolver
+):
+ r"""
+ Residual-based attention solver for physics-informed learning problems.
+
+ This solver approximates the solution of a differential problem using a
+ single model equipped with residual-based attention weights. It can be used
+ for both forward and inverse problems.
+
+ Given a model :math:`\mathcal{M}`, the predicted solution is
+
+ .. math::
+
+ \hat{\mathbf{u}}(\mathbf{x}) = \mathcal{M}(\mathbf{x}).
+
+ The solver minimizes a weighted objective in which each residual
+ contribution is scaled by an attention weight. For a problem with governing
+ equation operator :math:`\mathcal{A}` in the domain :math:`\Omega` and
+ boundary operator :math:`\mathcal{B}` on the boundary
+ :math:`\partial\Omega`, the objective can be written as
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} =
+ \frac{1}{N_{\Omega}} \sum_{i=1}^{N_{\Omega}}
+ \lambda_{\Omega}^{i} \mathcal{L}
+ \left( \mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i) \right)
+ + \frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
+ \lambda_{\partial\Omega}^{i} \mathcal{L}
+ \left( \mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_i) \right),
+
+ where :math:`\mathcal{L}` is the selected loss function, typically the
+ mean squared error, and :math:`\lambda_{\Omega}^{i}` and
+ :math:`\lambda_{\partial\Omega}^{i}` are the attention weights associated
+ with the domain and boundary residuals, respectively.
+
+ At each epoch, the attention weights are updated according to the magnitude
+ of the corresponding residuals:
+
+ .. math::
+
+ \lambda_i^{k+1} = \gamma \lambda_i^k + \eta \frac{|r_i|}{\max_j |r_j|},
+
+ where :math:`r_i` is the residual at point :math:`i`, :math:`\gamma` is the
+ decay rate, and :math:`\eta` is the learning rate used for the attention
+ weight update.
+
+ .. seealso::
+
+ **Original reference**: Anagnostopoulos, S. J., Toscano, J. D.,
+ Stergiopulos, N., & Karniadakis, G. E. (2024).
+ *Residual-based attention and connection to information bottleneck theory
+ in PINNs.*
+ Computer Methods in Applied Mechanics and Engineering, 421, 116805.
+ DOI: `10.1016/j.cma.2024.116805
+ `_.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (
+ InputTargetCondition,
+ InputEquationCondition,
+ DomainEquationCondition,
+ )
+
+ def __init__(
+ self,
+ problem,
+ model,
+ optimizer=None,
+ scheduler=None,
+ weighting=None,
+ loss=None,
+ eta=0.001,
+ gamma=0.999,
+ regularized_conditions=None,
+ ):
+ """
+ Initialization of the :class:`RBAPhysicsInformedSingleModelSolver`
+ class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param torch.nn.Module model: The model used by the solver.
+ :param TorchOptimizer optimizer: The optimizer used by the solver.
+ If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
+ of ``0.001`` is used. Default is ``None``.
+ :param TorchScheduler scheduler: The scheduler used by the solver.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param eta: The learning rate for the residual-based attention weights
+ update. Default is ``0.001``.
+ :type eta: float | int
+ :param float gamma: The decay factor for the residual-based attention
+ mechanism. Default is ``0.999``.
+ :param regularized_conditions: The names of the conditions that should
+ receive attention regularization. If ``None``, all conditions are
+ regularized. Default is ``None``.
+ :type regularized_conditions: str | list[str]
+ """
+ # Initialize the parent class
+ SingleModelSolver.__init__(
+ self,
+ problem=problem,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ weighting=weighting,
+ loss=loss,
+ use_lt=True,
+ )
+
+ # Initialize the residual-based attention components
+ self._init_residual_attention_components(
+ eta=eta,
+ gamma=gamma,
+ regularized_conditions=regularized_conditions,
+ )
diff --git a/pina/_src/solver/self_adaptive_physics_informed_solver.py b/pina/_src/solver/self_adaptive_physics_informed_solver.py
new file mode 100644
index 000000000..7f2b4032a
--- /dev/null
+++ b/pina/_src/solver/self_adaptive_physics_informed_solver.py
@@ -0,0 +1,344 @@
+"""Module for the self-adaptive physics-informed multi-model solver."""
+
+import torch
+from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
+from pina._src.condition.input_equation_condition import InputEquationCondition
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.solver.multi_model_solver import MultiModelSolver
+from pina._src.core.utils import check_consistency
+from pina._src.condition.domain_equation_condition import (
+ DomainEquationCondition,
+)
+
+
+class SelfAdaptivePhysicsInformedSolver(PhysicsInformedMixin, MultiModelSolver):
+ r"""
+ Multi-model solver for self-adaptive physics-informed learning problems.
+
+ This solver approximates the solution of a differential problem using a
+ trainable model together with condition-wise self-adaptive weights. It is
+ intended for problems whose conditions may include supervised data, equation
+ residuals evaluated on input points, and equation residuals sampled from
+ domains.
+
+ Given a model :math:`\mathcal{M}`, the predicted solution is
+
+ .. math::
+
+ \hat{\mathbf{u}}(\mathbf{x}) = \mathcal{M}(\mathbf{x}).
+
+ For each condition, the solver introduces trainable pointwise weights. These
+ weights are passed through a user-defined weight function :math:`m`,
+ typically chosen to keep the effective weights bounded or positive. The
+ resulting weighted objective encourages the model to focus on regions where
+ the residual is larger.
+
+ For a problem with governing equation operator :math:`\mathcal{A}` in the
+ domain :math:`\Omega` and boundary operator :math:`\mathcal{B}` on the
+ boundary :math:`\partial\Omega`, the objective can be written as
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{N_{\Omega}}
+ \sum_{i=1}^{N_{\Omega}} m(\lambda_{\Omega}^{i}) \mathcal{L}
+ \left( \mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i) \right)
+ + \frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
+ m(\lambda_{\partial\Omega}^{i})
+ \mathcal{L} \left( \mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_i) \right),
+
+ where :math:`\lambda_{\Omega}^{i}` and :math:`\lambda_{\partial\Omega}^{i}`
+ are the self-adaptive weights associated with points in :math:`\Omega` and
+ :math:`\partial\Omega`, respectively, and :math:`\mathcal{L}` is the
+ selected loss function, typically the mean squared error.
+
+ The model parameters and the self-adaptive weights are optimized through a
+ min-max problem:
+
+ .. math::
+
+ \min_{\theta} \max_{\lambda} \mathcal{L}_{\mathrm{problem}},
+
+ where :math:`\theta` denotes the model parameters and :math:`\lambda`
+ denotes the collection of self-adaptive weights.
+
+ .. seealso::
+
+ **Original reference**: McClenny, L. D., & Braga-Neto, U. M. (2023).
+ *Self-adaptive physics-informed neural networks.*
+ Journal of Computational Physics, 474, 111722.
+ DOI: `10.1016/j.jcp.2022.111722
+ `_.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (
+ InputTargetCondition,
+ InputEquationCondition,
+ DomainEquationCondition,
+ )
+
+ def __init__(
+ self,
+ problem,
+ model,
+ weight_function=torch.nn.Sigmoid(),
+ optimizer_model=None,
+ optimizer_weights=None,
+ scheduler_model=None,
+ scheduler_weights=None,
+ weighting=None,
+ loss=None,
+ ):
+ """
+ Initialization of the :class:`SelfAdaptivePhysicsInformedSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param torch.nn.Module model: The model used by the solver.
+ :param torch.nn.Module weight_function: The weight function used to
+ compute self-adaptive weights. Default is ``torch.nn.Sigmoid()``.
+ :param TorchOptimizer optimizer_model: The optimizer of the main model.
+ If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
+ of ``0.001`` is used. Default is ``None``.
+ :param TorchOptimizer optimizer_weights: The optimizer of the
+ self-adaptive weights. If ``None``, the ``torch.optim.Adam``
+ optimizer with a learning rate of ``0.001`` is used.
+ Default is ``None``.
+ :param TorchScheduler scheduler_model: The scheduler of the main model.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param TorchScheduler scheduler_weights: The scheduler of the
+ self-adaptive weights. If ``None``, the
+ ``torch.optim.lr_scheduler.ConstantLR`` scheduler with a factor of
+ ``1.0`` is used. Default is ``None``.
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :raises ValueError: If ``weight_function`` is not a ``torch.nn.Module``.
+ :raises ValueError: If not all domains have been discretised.
+ """
+ # Check consistency
+ check_consistency(weight_function, torch.nn.Module)
+
+ # Check that all domains have been discretised
+ if not problem.are_all_domains_discretised:
+ raise ValueError(
+ "All domains must be discretised before initializing the "
+ "solver."
+ )
+
+ # Compute the number of points for each condition
+ num_points = {
+ cond: (
+ problem._discretised_domains[cond].shape[0]
+ if isinstance(problem.conditions[cond], DomainEquationCondition)
+ else problem.conditions[cond].data.input.shape[0]
+ )
+ for cond in problem.conditions
+ }
+
+ # Initialize weights container and per-condition parameters
+ weights = torch.nn.Module()
+
+ # Attach the weight function as a submodule
+ weights.func = weight_function
+
+ # Register a torch.nn.Parameter for each condition to store the weights
+ for cond in problem.conditions:
+ p = torch.nn.Parameter(torch.zeros(num_points[cond], 1))
+ setattr(weights, cond, p)
+
+ # Prepare optimizers
+ optimizers = (
+ [optimizer_model, optimizer_weights]
+ if any(o is not None for o in (optimizer_model, optimizer_weights))
+ else None
+ )
+
+ # Prepare schedulers
+ schedulers = (
+ [scheduler_model, scheduler_weights]
+ if any(s is not None for s in (scheduler_model, scheduler_weights))
+ else None
+ )
+
+ # Initialize the base solver
+ MultiModelSolver.__init__(
+ self,
+ problem=problem,
+ models=[model, weights],
+ optimizers=optimizers,
+ schedulers=schedulers,
+ weighting=weighting,
+ loss=loss,
+ use_lt=True,
+ )
+
+ def training_step(self, batch, batch_idx):
+ """
+ Solver training step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+ # Zero the gradients of weights optimizer and compute the loss
+ self.optimizer_weights.instance.zero_grad()
+ loss = self.batch_evaluation_step(batch, batch_idx)
+
+ # Perform the backward pass and complete a step for the weights
+ self.manual_backward(-loss)
+ self.optimizer_weights.instance.step()
+ self.scheduler_weights.instance.step()
+
+ # Zero the gradients of model optimizer and compute the loss again
+ self.optimizer_model.instance.zero_grad()
+ loss = self.batch_evaluation_step(batch, batch_idx)
+
+ # Perform the backward pass and complete a step for the model
+ self.manual_backward(loss)
+ self.optimizer_model.instance.step()
+ self.scheduler_model.instance.step()
+
+ # Log the training loss
+ self.log(
+ name="train_loss",
+ value=loss.item(),
+ batch_size=self.get_batch_size(batch),
+ **self.trainer.logging_kwargs,
+ )
+
+ return loss
+
+ def forward(self, x):
+ """
+ Forward pass through the model.
+
+ :param x: The input data.
+ :type x: torch.Tensor | LabelTensor | Data | Graph
+ :return: The output of the model.
+ :rtype: torch.Tensor | LabelTensor | Data | Graph
+ """
+ return self.model(x)
+
+ def _compute_condition_loss(self, condition, data, batch_idx):
+ """
+ Compute the scalar loss for a given condition and its data.
+
+ :param BaseCondition condition: The condition for which to compute the
+ loss.
+ :param dict data: The data corresponding to the condition.
+ :param int batch_idx: The index of the current batch.
+ :return: The scalar loss for the condition.
+ :rtype: torch.Tensor
+ """
+ # Clone the input tensor if it exists to avoid in-place modifications
+ if "input" in data and hasattr(data["input"], "clone"):
+ data = dict(data)
+ data["input"] = data["input"].clone()
+
+ # Prepare condition data, e.g. by enabling gradient for regularizations
+ data = self._prepare_condition_data(data=data)
+
+ # Compute and store the residual tensor for the condition
+ self.residual_tensor = condition.evaluate(data, self)
+
+ # Retrieve condition name for more complex weighting schemes
+ condition_name = condition.name
+
+ # Apply the activation function to the condition-specific weights
+ weight_param = getattr(self.weights, condition_name)
+ weight_tensor = self.weights.func(weight_param)
+
+ # Compute the tensor loss from the residual tensor
+ condition_tensor_loss = self._loss_from_residual(condition_name)
+
+ # Optional regularization hook, e.g gradient-enhanced or residual-based
+ condition_tensor_loss = self._regularize_condition_loss(
+ condition_tensor_loss=condition_tensor_loss,
+ condition_name=condition_name,
+ data=data,
+ batch_idx=batch_idx,
+ )
+
+ # Get the correct indices to retrieve the weights for the current batch
+ len_residuals = self.residual_tensor.shape[0]
+
+ # Get the total number of points, together with the start / end indices
+ total_points = weight_param.shape[0]
+ start = (batch_idx * len_residuals) % total_points
+ end = start + len_residuals
+
+ # Retrieve the weights for the current batch using modular indexing
+ idx = torch.arange(start, end, device=self.residual_tensor.device)
+ idx = idx % total_points
+
+ # Compute the scalar loss from the tensor loss and return it
+ condition_scalar_loss = self._apply_reduction(
+ condition_tensor_loss * weight_tensor[idx]
+ )
+
+ return condition_scalar_loss
+
+ @property
+ def model(self):
+ """
+ The single model used by the solver.
+
+ :return: The single model used by the solver.
+ :rtype: torch.nn.Module
+ """
+ return self._pina_models[0]
+
+ @property
+ def weights(self):
+ """
+ The self-adaptive weights used by the solver.
+
+ :return: The self-adaptive weights used by the solver.
+ :rtype: torch.nn.Module
+ """
+ return self._pina_models[1]
+
+ @property
+ def optimizer_model(self):
+ """
+ The optimizer for the model used by the solver.
+
+ :return: The optimizer for the model used by the solver.
+ :rtype: TorchOptimizer
+ """
+ return self.optimizers[0]
+
+ @property
+ def optimizer_weights(self):
+ """
+ The optimizer for the weights used by the solver.
+
+ :return: The optimizer for the weights used by the solver.
+ :rtype: TorchOptimizer
+ """
+ return self.optimizers[1]
+
+ @property
+ def scheduler_model(self):
+ """
+ The scheduler for the model used by the solver.
+
+ :return: The scheduler for the model used by the solver.
+ :rtype: TorchScheduler
+ """
+ return self.schedulers[0]
+
+ @property
+ def scheduler_weights(self):
+ """
+ The scheduler for the weights used by the solver.
+
+ :return: The scheduler for the weights used by the solver.
+ :rtype: TorchScheduler
+ """
+ return self.schedulers[1]
diff --git a/pina/_src/solver/single_model_solver.py b/pina/_src/solver/single_model_solver.py
new file mode 100644
index 000000000..265c431c9
--- /dev/null
+++ b/pina/_src/solver/single_model_solver.py
@@ -0,0 +1,65 @@
+"""Module for the single-model solver class."""
+
+from pina._src.solver.mixin.single_model_mixin import SingleModelMixin
+from pina._src.solver.base_solver import BaseSolver
+from pina._src.solver.mixin.condition_aggregator_mixin import (
+ ConditionAggregatorMixin,
+)
+
+
+class SingleModelSolver(SingleModelMixin, ConditionAggregatorMixin, BaseSolver):
+ """
+ Base class for implementing single-model solvers.
+
+ This class provides the standard starting point for solvers based on a
+ single model. It combines the shared solver machinery from
+ :class:`~pina._src.solver.base_solver.BaseSolver` with single-model handling
+ and condition-wise loss aggregation.
+
+ Subclasses can inherit from this class to implement solver-specific behavior
+ while reusing the common logic for model registration, optimizer and
+ scheduler setup, loss evaluation, weighting, and aggregation across problem
+ conditions.
+ """
+
+ def __init__(
+ self,
+ problem,
+ model,
+ optimizer=None,
+ scheduler=None,
+ weighting=None,
+ loss=None,
+ use_lt=True,
+ ):
+ """
+ Initialization of the :class:`SingleModelSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param torch.nn.Module model: The model used by the solver.
+ :param TorchOptimizer optimizer: The optimizer used by the solver.
+ If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
+ of ``0.001`` is used. Default is ``None``.
+ :param TorchScheduler scheduler: The scheduler used by the solver.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+ Default is ``True``.
+ """
+ # Initialize the base solver
+ BaseSolver.__init__(self, problem=problem, use_lt=use_lt)
+
+ # Initialize the components of the solver
+ self._init_solver_components(
+ models=model,
+ optimizers=optimizer,
+ schedulers=scheduler,
+ )
+
+ # Initialize the weighting scheme for the conditions and the loss
+ self._init_weighting_and_loss(weighting=weighting, loss=loss)
diff --git a/pina/_src/solver/solver_interface.py b/pina/_src/solver/solver_interface.py
new file mode 100644
index 000000000..c6cab1b18
--- /dev/null
+++ b/pina/_src/solver/solver_interface.py
@@ -0,0 +1,92 @@
+"""Module for the solver interface."""
+
+from abc import ABCMeta, abstractmethod
+import lightning
+
+
+class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
+ """
+ Abstract base class for PINA solvers. All specific solvers must inherit
+ from this interface. This class extends
+ :class:`~lightning.pytorch.core.LightningModule`, providing additional
+ functionalities for defining and optimizing Deep Learning models.
+
+ By inheriting from this base class, solvers gain access to built-in training
+ loops, logging utilities, and optimization techniques.
+ """
+
+ @abstractmethod
+ def training_step(self, batch, batch_idx):
+ """
+ Solver training step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+
+ @abstractmethod
+ def validation_step(self, batch, batch_idx):
+ """
+ Solver validation step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+
+ @abstractmethod
+ def test_step(self, batch, batch_idx):
+ """
+ Solver test step.
+
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+ tuple containing a condition name and a dictionary of points.
+ :param int batch_idx: The index of the current batch.
+ :return: The loss of the training step.
+ :rtype: torch.Tensor
+ """
+
+ @property
+ @abstractmethod
+ def problem(self):
+ """
+ The problem instance.
+
+ :return: The problem instance.
+ :rtype: :class:`~pina.problem.base_problem.BaseProblem`
+ """
+
+ @property
+ @abstractmethod
+ def use_lt(self):
+ """
+ Using LabelTensors as input during training.
+
+ :return: The use_lt attribute.
+ :rtype: bool
+ """
+
+ @property
+ @abstractmethod
+ def weighting(self):
+ """
+ The weighting schema used by the solver.
+
+ :return: The weighting schema used by the solver.
+ :rtype: :class:`~pina.weighting.base_weighting.BaseWeighting`
+ """
+
+ @property
+ @abstractmethod
+ def loss(self):
+ """
+ The element-wise loss module used by the solver.
+
+ :return: The element-wise loss module used by the solver.
+ :rtype: torch.nn.Module
+ """
diff --git a/pina/_src/solver/supervised_ensemble_solver.py b/pina/_src/solver/supervised_ensemble_solver.py
new file mode 100644
index 000000000..d602f3fe0
--- /dev/null
+++ b/pina/_src/solver/supervised_ensemble_solver.py
@@ -0,0 +1,84 @@
+"""Module for the supervised ensemble-model solver class."""
+
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.solver.ensemble_solver import EnsembleSolver
+
+
+class SupervisedEnsembleSolver(EnsembleSolver):
+ r"""
+ Ensemble-model solver for supervised learning problems.
+
+ This solver approximates the mapping between input data and target data
+ using an ensemble of models. It is intended for problems whose conditions
+ are defined by input-target pairs and accepts only
+ :class:`~pina._src.condition.input_target_condition.InputTargetCondition`.
+
+ Given input samples :math:`\mathbf{s}_i`, target values
+ :math:`\mathbf{u}_i`, and an ensemble of models
+ :math:`\{\mathcal{M}_j\}_{j=1}^{M}`, the prediction of each model is
+
+ .. math::
+
+ \hat{\mathbf{u}}_{i}^{(j)} = \mathcal{M}_j(\mathbf{s}_i),
+ \qquad j = 1, \ldots, M.
+
+ The supervised training objective minimizes the discrepancy between the
+ target values and the ensemble predictions:
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{M} \sum_{j=1}^{M}
+ \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}
+ \left( \mathbf{u}_i - \hat{\mathbf{u}}_{i}^{(j)} \right),
+
+ where :math:`\mathcal{L}` is the selected loss function, typically the
+ mean squared error.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (InputTargetCondition,)
+
+ def __init__(
+ self,
+ problem,
+ models,
+ optimizers=None,
+ schedulers=None,
+ weighting=None,
+ loss=None,
+ use_lt=True,
+ ):
+ """
+ Initialization of the :class:`SupervisedEnsembleSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param models: The model or list of models used by the solver.
+ :type models: torch.nn.Module | list[torch.nn.Module]
+ :param optimizers: The optimizer or list of optimizers used by the
+ solver. If ``None``, the ``torch.optim.Adam`` optimizer with a
+ learning rate of ``0.001`` is used for each model.
+ Default is ``None``.
+ :type optimizers: TorchOptimizer | list[TorchOptimizer]
+ :param schedulers: The scheduler or list of schedulers used by the
+ solver. If ``None``, the ``torch.optim.lr_scheduler.ConstantLR``
+ scheduler with a factor of ``1.0`` is used for each model.
+ Default is ``None``.
+ :type schedulers: TorchScheduler | list[TorchScheduler]
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+ Default is ``True``.
+ """
+ EnsembleSolver.__init__(
+ self,
+ problem=problem,
+ models=models,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ weighting=weighting,
+ loss=loss,
+ use_lt=use_lt,
+ )
diff --git a/pina/_src/solver/supervised_single_model_solver.py b/pina/_src/solver/supervised_single_model_solver.py
new file mode 100644
index 000000000..428de8db4
--- /dev/null
+++ b/pina/_src/solver/supervised_single_model_solver.py
@@ -0,0 +1,74 @@
+"""Module for the supervised single-model solver class."""
+
+from pina._src.condition.input_target_condition import InputTargetCondition
+from pina._src.solver.single_model_solver import SingleModelSolver
+
+
+class SupervisedSingleModelSolver(SingleModelSolver):
+ r"""
+ Single-model solver for supervised learning problems.
+
+ This solver is designed for problems defined by input-target pairs and uses
+ a single model to approximate the mapping from input variables to target
+ variables. It supports only
+ :class:`~pina._src.condition.input_target_condition.InputTargetCondition`
+ conditions.
+
+ Given a model :math:`\mathcal{M}`, the solver minimizes the discrepancy
+ between the target values :math:`\mathbf{u}_i` and the model predictions
+ :math:`\mathcal{M}(\mathbf{s}_i)` evaluated at the input data
+ :math:`\mathbf{s}_i`.
+
+ The supervised loss minimized during training is
+
+ .. math::
+
+ \mathcal{L}_{\mathrm{problem}} = \frac{1}{N} \sum_{i=1}^{N}
+ \mathcal{L} \left( \mathbf{u}_i - \mathcal{M}(\mathbf{s}_i) \right),
+
+ where :math:`\mathcal{L}` is the selected loss function, typically the mean
+ squared error.
+ """
+
+ # Accepted conditions types for this solver
+ accepted_conditions_types = (InputTargetCondition,)
+
+ def __init__(
+ self,
+ problem,
+ model,
+ optimizer=None,
+ scheduler=None,
+ weighting=None,
+ loss=None,
+ use_lt=True,
+ ):
+ """
+ Initialization of the :class:`SupervisedSingleModelSolver` class.
+
+ :param BaseProblem problem: The problem to be solved.
+ :param torch.nn.Module model: The model used by the solver.
+ :param TorchOptimizer optimizer: The optimizer used by the solver.
+ If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
+ of ``0.001`` is used. Default is ``None``.
+ :param TorchScheduler scheduler: The scheduler used by the solver.
+ If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
+ with a factor of ``1.0`` is used. Default is ``None``.
+ :param BaseWeighting weighting: The weighting strategy used to combine
+ condition losses. If ``None``, no weighting is applied. Default is
+ ``None``.
+ :param loss: The loss function used to compute residual losses.
+ If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
+ :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+ Default is ``True``.
+ """
+ SingleModelSolver.__init__(
+ self,
+ problem=problem,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ weighting=weighting,
+ loss=loss,
+ use_lt=use_lt,
+ )
diff --git a/pina/_src/weighting/__init__.py b/pina/_src/weighting/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/weighting/base_weighting.py b/pina/_src/weighting/base_weighting.py
new file mode 100644
index 000000000..2208009cb
--- /dev/null
+++ b/pina/_src/weighting/base_weighting.py
@@ -0,0 +1,109 @@
+"""Module for the Base Weighting class."""
+
+from typing import final, Callable
+import torch
+from pina._src.weighting.weighting_interface import WeightingInterface
+from pina._src.core.utils import check_positive_integer, check_consistency
+
+
+class BaseWeighting(WeightingInterface):
+ """
+ Base class for all weighting schemas, implementing common functionality.
+
+ A weighting schema defines how scalar loss terms coming from different
+ conditions are aggregated into a single scalar loss.
+
+ All weighting schemas must inherit from this class and implement the methods
+ defined in :class:`~pina.weighting.weighting_interface.WeightingInterface`.
+
+ This class is not meant to be instantiated directly.
+ """
+
+ # Supported aggregation methods
+ _AGGREGATE_METHODS = {"sum": torch.sum, "mean": torch.mean}
+
+ def __init__(self, update_every_n_epochs=1, aggregator="sum"):
+ """
+ Initialization of the :class:`BaseWeighting` class.
+
+ :param int update_every_n_epochs: The number of training epochs between
+ weight updates. If set to 1, the weights are updated at every epoch.
+ This parameter is ignored by static weighting schemes.
+ Default is ``1``.
+ :param aggregator: The aggregation method. Available options include:
+ ``"sum"`` which sums the weighted losses, ``"mean"`` which averages
+ the weighted losses, or a custom callable that takes an iterable of
+ weighted losses and returns a single scalar. Default is ``"sum"``.
+ :type aggregator: str | Callable
+ :raises ValueError: If ``update_every_n_epochs`` is not a positive
+ integer.
+ :raises ValueError: If ``aggregator`` is invalid.
+ """
+ # Check consistency
+ check_positive_integer(value=update_every_n_epochs, strict=True)
+ check_consistency(aggregator, (str, Callable))
+
+ # Validate aggregator
+ if isinstance(aggregator, str):
+ if aggregator not in self._AGGREGATE_METHODS:
+ raise ValueError(
+ f"Invalid aggregator '{aggregator}'. Available options: "
+ f"{list(self._AGGREGATE_METHODS.keys())}. Got {aggregator}."
+ )
+ aggregator = self._AGGREGATE_METHODS[aggregator]
+
+ # Initialization
+ self.update_every_n_epochs = update_every_n_epochs
+ self.aggregator_fn = aggregator
+ self._solver = None
+ self._saved_weights = {}
+
+ @final
+ def aggregate(self, losses):
+ """
+ Aggregate a collection of loss terms into a single scalar.
+
+ This method applies the current weighting scheme to the provided losses
+ and returns the aggregated result. Implementations may internally update
+ the weights (e.g., based on training state) before performing the
+ aggregation.
+
+ :param dict losses: The mapping from loss names to loss tensors.
+ :return: The aggregated loss value.
+ :rtype: torch.Tensor
+ """
+ # Update weights when required
+ if self.solver.trainer.current_epoch % self.update_every_n_epochs == 0:
+ self._saved_weights = self.update_weights(losses)
+
+ # Apply weights to the corresponding losses
+ weighted_losses = torch.stack(
+ [
+ (self._saved_weights[condition] * loss).reshape(-1)
+ for condition, loss in losses.items()
+ ]
+ )
+
+ return self.aggregator_fn(weighted_losses)
+
+ def last_saved_weights(self):
+ """
+ Get the most recently computed weights.
+
+ :return: The mapping from loss names to their corresponding weights.
+ :rtype: dict
+ """
+ return self._saved_weights
+
+ @property
+ def solver(self):
+ """
+ Solver associated with this weighting strategy.
+
+ Provides access to the solver instance that uses this weighting scheme,
+ enabling strategies that depend on training state or model information.
+
+ :return: The solver instance.
+ :rtype: :class:`~pina.solver.base_solver.BaseSolver`
+ """
+ return self._solver
diff --git a/pina/_src/weighting/linear_weighting.py b/pina/_src/weighting/linear_weighting.py
new file mode 100644
index 000000000..e57962c81
--- /dev/null
+++ b/pina/_src/weighting/linear_weighting.py
@@ -0,0 +1,79 @@
+"""Module for the Linear Weighting class."""
+
+from pina._src.weighting.base_weighting import BaseWeighting
+from pina._src.core.utils import check_consistency, check_positive_integer
+
+
+class LinearWeighting(BaseWeighting):
+ """
+ Weighting strategy based on linear interpolation over training epochs.
+
+ This scheme progressively adjusts the weights assigned to each loss term,
+ transitioning from a set of initial values to corresponding final values.
+ The update follows a linear schedule and is applied at each epoch until the
+ specified target epoch is reached.
+ """
+
+ def __init__(self, initial_weights, final_weights, target_epoch):
+ """
+ Initialization of the :class:`LinearWeighting` class.
+
+ :param dict initial_weights: The mapping of loss identifiers to their
+ initial weights at the start of training. Keys represent loss terms
+ (e.g., conditions) and values are the corresponding weights. Loss
+ terms not explicitly specified default to a weight of ``1``.
+ :param dict final_weights: The mapping of loss identifiers to their
+ target weights at the specified ``target_epoch``. Keys must match
+ those of ``initial_weights``. Loss terms not explicitly specified
+ default to a weight of ``1``.
+ :param int target_epoch: The epoch at which the weights reach their
+ final values. The interpolation progresses linearly from epoch ``0``
+ to ``target_epoch``. After ``target_epoch``, the weights remain
+ constant at their final values.
+ :raises ValueError: If ``initial_weights`` is not a dictionary.
+ :raises ValueError: If ``final_weights`` is not a dictionary.
+ :raises ValueError: If ``target_epoch`` is not a positive integer.
+ :raises ValueError: If the keys of the two dictionaries are not
+ consistent.
+ """
+ super().__init__(update_every_n_epochs=1, aggregator="sum")
+
+ # Check consistency
+ check_consistency([initial_weights, final_weights], dict)
+ check_positive_integer(value=target_epoch, strict=True)
+
+ # Check that the keys of the two dictionaries match
+ if initial_weights.keys() != final_weights.keys():
+ raise ValueError(
+ "The keys of the provided dictionaries for initial and final "
+ f"weights must match. Got {initial_weights.keys()} as initial "
+ f"weight keys and {final_weights.keys()} as final weight keys."
+ )
+
+ # Initialization
+ self.initial_weights = initial_weights
+ self.final_weights = final_weights
+ self.target_epoch = target_epoch
+
+ def update_weights(self, losses):
+ """
+ Update the weights based on the current losses.
+
+ This method defines how the weighting strategy adapts over time. It is
+ responsible for computing and storing updated weights that will be used
+ during aggregation.
+
+ :param dict losses: The mapping from loss names to loss tensors.
+ :return: The updated weights.
+ :rtype: dict
+ """
+ # Determine the progress towards the target epoch
+ epoch = min(self.solver.trainer.current_epoch, self.target_epoch)
+ progress = epoch / self.target_epoch
+
+ return {
+ condition: self.initial_weights[condition]
+ + (self.final_weights[condition] - self.initial_weights[condition])
+ * progress
+ for condition in losses.keys()
+ }
diff --git a/pina/_src/weighting/no_weighting.py b/pina/_src/weighting/no_weighting.py
new file mode 100644
index 000000000..89507409e
--- /dev/null
+++ b/pina/_src/weighting/no_weighting.py
@@ -0,0 +1,16 @@
+from pina._src.weighting.scalar_weighting import ScalarWeighting
+
+
+class _NoWeighting(ScalarWeighting):
+ """
+ Weighting strategy that leaves all loss terms unchanged.
+
+ This is a special case of scalar weighting where a unit weight is assigned
+ to every loss term, resulting in no reweighting.
+ """
+
+ def __init__(self):
+ """
+ Initialization of the :class:`_NoWeighting` class.
+ """
+ super().__init__(weights=1)
diff --git a/pina/_src/weighting/ntk_weighting.py b/pina/_src/weighting/ntk_weighting.py
new file mode 100644
index 000000000..702d0655c
--- /dev/null
+++ b/pina/_src/weighting/ntk_weighting.py
@@ -0,0 +1,96 @@
+"""Module for Neural Tangent Kernel Class"""
+
+import torch
+from pina._src.weighting.base_weighting import BaseWeighting
+from pina._src.core.utils import check_consistency, in_range
+
+
+class NeuralTangentKernelWeighting(BaseWeighting):
+ """
+ The Neural Tangent Kernel (NTK) weighting strategy.
+
+ This weighting scheme dynamically adjusts the contribution of each loss term
+ during training by leveraging gradient information with respect to the model
+ parameters. For each loss component, the norm of its gradient is computed
+ and used to derive relative importance weights. The resulting weights are
+ smoothed over time using an exponential moving average controlled by the
+ parameter ``alpha``.
+
+ .. seealso::
+
+ **Original reference**: Wang, Sifan, Xinling Yu, and
+ Paris Perdikaris. *When and why PINNs fail to train:
+ A neural tangent kernel perspective*. Journal of
+ Computational Physics 449 (2022): 110768.
+ DOI: `10.1016 `_.
+ """
+
+ def __init__(self, update_every_n_epochs=1, alpha=0.5):
+ """
+ Initialization of the :class:`NeuralTangentKernelWeighting` class.
+
+ :param int update_every_n_epochs: The number of training epochs between
+ weight updates. If set to 1, the weights are updated at every epoch.
+ This parameter is ignored by static weighting schemes.
+ Default is ``1``.
+ :param float alpha: The parameter controlling the exponential moving
+ average of the weights. It must be in the range [0, 1], where a
+ value of ``0.0`` means that only the current gradient norms are used
+ to compute the weights, and a value of ``1.0`` means that only the
+ last saved weights are used. Default is ``0.5``.
+ :raises ValueError: If ``alpha`` is not a float.
+ :raises ValueError: If ``alpha`` is not between 0.0 and 1.0 (inclusive).
+ """
+ super().__init__(
+ update_every_n_epochs=update_every_n_epochs, aggregator="sum"
+ )
+
+ # Check consistency
+ check_consistency(alpha, float)
+ if not in_range(alpha, [0.0, 1.0], strict=False):
+ raise ValueError(
+ "The alpha parameter must be between 0.0 and 1.0 (inclusive)."
+ f" Got {alpha}."
+ )
+
+ # Initialization
+ self.alpha = alpha
+ self.weights = {}
+
+ def update_weights(self, losses):
+ """
+ Update the weights based on the current losses.
+
+ This method defines how the weighting strategy adapts over time. It is
+ responsible for computing and storing updated weights that will be used
+ during aggregation.
+
+ :param dict losses: The mapping from loss names to loss tensors.
+ :return: The updated weights.
+ :rtype: dict
+ """
+ # Get model parameters and define a dictionary to store the norms
+ params = [p for p in self.solver.model.parameters() if p.requires_grad]
+ norms = {}
+
+ # Iterate over conditions
+ for condition, loss in losses.items():
+
+ # Compute gradients
+ grads = torch.autograd.grad(
+ loss,
+ params,
+ retain_graph=True,
+ allow_unused=True,
+ )
+
+ # Compute norms
+ norms[condition] = torch.cat(
+ [g.flatten() for g in grads if g is not None]
+ ).norm()
+
+ return {
+ condition: self.alpha * self.last_saved_weights().get(condition, 1)
+ + (1 - self.alpha) * norms[condition] / sum(norms.values())
+ for condition in losses
+ }
diff --git a/pina/_src/weighting/scalar_weighting.py b/pina/_src/weighting/scalar_weighting.py
new file mode 100644
index 000000000..d977abf67
--- /dev/null
+++ b/pina/_src/weighting/scalar_weighting.py
@@ -0,0 +1,59 @@
+"""Module for the Scalar Weighting."""
+
+from pina._src.weighting.base_weighting import BaseWeighting
+from pina._src.core.utils import check_consistency
+
+
+class ScalarWeighting(BaseWeighting):
+ """
+ Weighting strategy based on fixed scalar coefficients.
+
+ This scheme assigns a constant multiplicative weight to each loss term,
+ without adapting over time. The same weight can be applied to all terms,
+ or distinct weights can be specified for individual conditions.
+ """
+
+ def __init__(self, weights):
+ """
+ Initialization of the :class:`ScalarWeighting` class.
+
+ :param weights: The scalar weights associated with each loss term. It
+ can be provided either as a single numeric value or as a dictionary.
+ If a scalar is given, the same weight is applied to all loss terms.
+ If a dictionary is provided, its keys represent the loss identifiers
+ (e.g., conditions) and its values specify the corresponding weights.
+ Loss terms not explicitly defined in the dictionary are assigned a
+ default weight of ``1``.
+ :type weights: float | int | dict
+ :raises ValueError: If the input weights are neither numeric nor a
+ dictionary.
+ """
+ super().__init__(update_every_n_epochs=1, aggregator="sum")
+
+ # Check consistency
+ check_consistency([weights], (float, dict, int))
+
+ # Initialization
+ if isinstance(weights, dict):
+ self.values = weights
+ self.default_value_weights = 1
+ else:
+ self.values = {}
+ self.default_value_weights = weights
+
+ def update_weights(self, losses):
+ """
+ Update the weights based on the current losses.
+
+ This method defines how the weighting strategy adapts over time. It is
+ responsible for computing and storing updated weights that will be used
+ during aggregation.
+
+ :param dict losses: The mapping from loss names to loss tensors.
+ :return: The updated weights.
+ :rtype: dict
+ """
+ return {
+ condition: self.values.get(condition, self.default_value_weights)
+ for condition in losses.keys()
+ }
diff --git a/pina/loss/self_adaptive_weighting.py b/pina/_src/weighting/self_adaptive_weighting.py
similarity index 54%
rename from pina/loss/self_adaptive_weighting.py
rename to pina/_src/weighting/self_adaptive_weighting.py
index c796d359f..d954fe635 100644
--- a/pina/loss/self_adaptive_weighting.py
+++ b/pina/_src/weighting/self_adaptive_weighting.py
@@ -1,14 +1,22 @@
"""Module for Self-Adaptive Weighting class."""
import torch
-from .weighting_interface import WeightingInterface
+from pina._src.weighting.base_weighting import BaseWeighting
-class SelfAdaptiveWeighting(WeightingInterface):
+class SelfAdaptiveWeighting(BaseWeighting):
"""
- A self-adaptive weighting scheme to tackle the imbalance among the loss
- components. This formulation equalizes the gradient norms of the losses,
- preventing bias toward any particular term during training.
+ The self-adaptive weighting strategy based on gradient norm balancing.
+
+ This scheme dynamically adjusts the weights assigned to each loss term by
+ computing the norm of their gradients with respect to the model parameters.
+ The resulting weights are chosen to counterbalance disparities in gradient
+ magnitudes, promoting a more uniform contribution of all loss components
+ during optimization.
+
+ In practice, loss terms with smaller gradient norms are assigned larger
+ weights, while those with larger gradients are down-weighted. This helps
+ mitigate training imbalance and prevents dominance of specific loss terms.
.. seealso::
@@ -18,7 +26,6 @@ class SelfAdaptiveWeighting(WeightingInterface):
Networks*.
DOI: `arXiv preprint arXiv:2507.08972.
`_
-
"""
def __init__(self, update_every_n_epochs=1):
@@ -27,15 +34,22 @@ def __init__(self, update_every_n_epochs=1):
:param int update_every_n_epochs: The number of training epochs between
weight updates. If set to 1, the weights are updated at every epoch.
- Default is 1.
+ This parameter is ignored by static weighting schemes.
+ Default is ``1``.
"""
- super().__init__(update_every_n_epochs=update_every_n_epochs)
+ super().__init__(
+ update_every_n_epochs=update_every_n_epochs, aggregator="sum"
+ )
- def weights_update(self, losses):
+ def update_weights(self, losses):
"""
- Update the weighting scheme based on the given losses.
+ Update the weights based on the current losses.
+
+ This method defines how the weighting strategy adapts over time. It is
+ responsible for computing and storing updated weights that will be used
+ during aggregation.
- :param dict losses: The dictionary of losses.
+ :param dict losses: The mapping from loss names to loss tensors.
:return: The updated weights.
:rtype: dict
"""
diff --git a/pina/_src/weighting/weighting_interface.py b/pina/_src/weighting/weighting_interface.py
new file mode 100644
index 000000000..7871a5f55
--- /dev/null
+++ b/pina/_src/weighting/weighting_interface.py
@@ -0,0 +1,60 @@
+"""Module for the Weighting Interface."""
+
+from abc import ABCMeta, abstractmethod
+
+
+class WeightingInterface(metaclass=ABCMeta):
+ """
+ Abstract interface for all weighting schemas.
+ """
+
+ @abstractmethod
+ def aggregate(self, losses):
+ """
+ Aggregate a collection of loss terms into a single scalar.
+
+ This method applies the current weighting scheme to the provided losses
+ and returns the aggregated result. Implementations may internally update
+ the weights (e.g., based on training state) before performing the
+ aggregation.
+
+ :param dict losses: The mapping from loss names to loss tensors.
+ :return: The aggregated loss value.
+ :rtype: torch.Tensor
+ """
+
+ @abstractmethod
+ def update_weights(self, losses):
+ """
+ Update the weights based on the current losses.
+
+ This method defines how the weighting strategy adapts over time. It is
+ responsible for computing and storing updated weights that will be used
+ during aggregation.
+
+ :param dict losses: The mapping from loss names to loss tensors.
+ :return: The updated weights.
+ :rtype: dict
+ """
+
+ @abstractmethod
+ def last_saved_weights(self):
+ """
+ Get the most recently computed weights.
+
+ :return: The mapping from loss names to their corresponding weights.
+ :rtype: dict
+ """
+
+ @property
+ @abstractmethod
+ def solver(self):
+ """
+ Solver associated with this weighting strategy.
+
+ Provides access to the solver instance that uses this weighting scheme,
+ enabling strategies that depend on training state or model information.
+
+ :return: The solver instance.
+ :rtype: :class:`~pina.solver.base_solver.BaseSolver`
+ """
diff --git a/pina/adaptive_function/__init__.py b/pina/adaptive_function/__init__.py
index d53c5f368..d41f25ccd 100644
--- a/pina/adaptive_function/__init__.py
+++ b/pina/adaptive_function/__init__.py
@@ -1,33 +1,43 @@
-"""Adaptive Activation Functions Module."""
+"""Adaptive activation functions with learnable parameters.
+
+This module provides implementations of standard activation functions (ReLU,
+SiLU, Tanh, etc.) augmented with trainable weights, as well as specialized
+functions like SIREN, designed to improve convergence in PINNs and Neural
+Operators.
+"""
__all__ = [
- "AdaptiveActivationFunctionInterface",
+ "AdaptiveFunctionInterface",
+ "BaseAdaptiveFunction",
+ "AdaptiveCELU",
+ "AdaptiveELU",
+ "AdaptiveExp",
+ "AdaptiveGELU",
+ "AdaptiveMish",
"AdaptiveReLU",
"AdaptiveSigmoid",
- "AdaptiveTanh",
"AdaptiveSiLU",
- "AdaptiveMish",
- "AdaptiveELU",
- "AdaptiveCELU",
- "AdaptiveGELU",
- "AdaptiveSoftmin",
- "AdaptiveSoftmax",
"AdaptiveSIREN",
- "AdaptiveExp",
+ "AdaptiveSoftmax",
+ "AdaptiveSoftmin",
+ "AdaptiveTanh",
]
-from .adaptive_function import (
- AdaptiveReLU,
- AdaptiveSigmoid,
- AdaptiveTanh,
- AdaptiveSiLU,
- AdaptiveMish,
- AdaptiveELU,
- AdaptiveCELU,
- AdaptiveGELU,
- AdaptiveSoftmin,
- AdaptiveSoftmax,
- AdaptiveSIREN,
- AdaptiveExp,
+from pina._src.adaptive_function.adaptive_function_interface import (
+ AdaptiveFunctionInterface,
+)
+from pina._src.adaptive_function.base_adaptive_function import (
+ BaseAdaptiveFunction,
)
-from .adaptive_function_interface import AdaptiveActivationFunctionInterface
+from pina._src.adaptive_function.adaptive_celu import AdaptiveCELU
+from pina._src.adaptive_function.adaptive_elu import AdaptiveELU
+from pina._src.adaptive_function.adaptive_exp import AdaptiveExp
+from pina._src.adaptive_function.adaptive_gelu import AdaptiveGELU
+from pina._src.adaptive_function.adaptive_mish import AdaptiveMish
+from pina._src.adaptive_function.adaptive_relu import AdaptiveReLU
+from pina._src.adaptive_function.adaptive_sigmoid import AdaptiveSigmoid
+from pina._src.adaptive_function.adaptive_silu import AdaptiveSiLU
+from pina._src.adaptive_function.adaptive_siren import AdaptiveSIREN
+from pina._src.adaptive_function.adaptive_softmax import AdaptiveSoftmax
+from pina._src.adaptive_function.adaptive_softmin import AdaptiveSoftmin
+from pina._src.adaptive_function.adaptive_tanh import AdaptiveTanh
diff --git a/pina/adaptive_function/adaptive_function.py b/pina/adaptive_function/adaptive_function.py
deleted file mode 100644
index e6f86a549..000000000
--- a/pina/adaptive_function/adaptive_function.py
+++ /dev/null
@@ -1,509 +0,0 @@
-"""Module for the Adaptive Functions."""
-
-import torch
-from ..utils import check_consistency
-from .adaptive_function_interface import AdaptiveActivationFunctionInterface
-
-
-class AdaptiveReLU(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.ReLU` activation function.
-
- Given the function :math:`\text{ReLU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{ReLU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{ReLU}_{\text{adaptive}}({x})=\alpha\,\text{ReLU}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- ReLU function is defined as:
-
- .. math::
- \text{ReLU}(x) = \max(0, x)
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.ReLU()
-
-
-class AdaptiveSigmoid(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.Sigmoid` activation function.
-
- Given the function
- :math:`\text{Sigmoid}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{Sigmoid}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{Sigmoid}_{\text{adaptive}}({x})=
- \alpha\,\text{Sigmoid}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- Sigmoid function is defined as:
-
- .. math::
- \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.Sigmoid()
-
-
-class AdaptiveTanh(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.Tanh` activation function.
-
- Given the function :math:`\text{Tanh}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{Tanh}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{Tanh}_{\text{adaptive}}({x})=\alpha\,\text{Tanh}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- Tanh function is defined as:
-
- .. math::
- \text{Tanh}(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.Tanh()
-
-
-class AdaptiveSiLU(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.SiLU` activation function.
-
- Given the function :math:`\text{SiLU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{SiLU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{SiLU}_{\text{adaptive}}({x})=\alpha\,\text{SiLU}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- SiLU function is defined as:
-
- .. math::
- \text{SiLU}(x) = x * \sigma(x), \text{where }\sigma(x)
- \text{ is the logistic sigmoid.}
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.SiLU()
-
-
-class AdaptiveMish(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.Mish` activation function.
-
- Given the function :math:`\text{Mish}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{Mish}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{Mish}_{\text{adaptive}}({x})=\alpha\,\text{Mish}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- Mish function is defined as:
-
- .. math::
- \text{Mish}(x) = x * \text{Tanh}(x)
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.Mish()
-
-
-class AdaptiveELU(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.ELU` activation function.
-
- Given the function :math:`\text{ELU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{ELU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{ELU}_{\text{adaptive}}({x}) = \alpha\,\text{ELU}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- ELU function is defined as:
-
- .. math::
- \text{ELU}(x) = \begin{cases}
- x, & \text{ if }x > 0\\
- \exp(x) - 1, & \text{ if }x \leq 0
- \end{cases}
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.ELU()
-
-
-class AdaptiveCELU(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.CELU` activation function.
-
- Given the function :math:`\text{CELU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{CELU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{CELU}_{\text{adaptive}}({x})=\alpha\,\text{CELU}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- CELU function is defined as:
-
- .. math::
- \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.CELU()
-
-
-class AdaptiveGELU(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.GELU` activation function.
-
- Given the function :math:`\text{GELU}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{GELU}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{GELU}_{\text{adaptive}}({x})=\alpha\,\text{GELU}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- GELU function is defined as:
-
- .. math::
- \text{GELU}(x)=0.5*x*(1+\text{Tanh}(\sqrt{2 / \pi}*(x+0.044715*x^3)))
-
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.GELU()
-
-
-class AdaptiveSoftmin(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.Softmin` activation function.
-
- Given the function
- :math:`\text{Softmin}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{Softmin}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{Softmin}_{\text{adaptive}}({x})=\alpha\,
- \text{Softmin}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- Softmin function is defined as:
-
- .. math::
- \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.Softmin()
-
-
-class AdaptiveSoftmax(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :class:`~torch.nn.Softmax` activation function.
-
- Given the function
- :math:`\text{Softmax}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{Softmax}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{Softmax}_{\text{adaptive}}({x})=\alpha\,
- \text{Softmax}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters, and the
- Softmax function is defined as:
-
- .. math::
- \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
- `_.
-
- Jagtap, Ameya D., Kenji Kawaguchi, and George Em Karniadakis. *Adaptive
- activation functions accelerate convergence in deep and
- physics-informed neural networks*. Journal of
- Computational Physics 404 (2020): 109136.
- DOI: `JCP 10.1016
- `_.
- """
-
- def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
- super().__init__(alpha, beta, gamma, fixed)
- self._func = torch.nn.Softmax()
-
-
-class AdaptiveSIREN(AdaptiveActivationFunctionInterface):
- r"""
- Adaptive trainable :obj:`~torch.sin` function.
-
- Given the function :math:`\text{sin}:\mathbb{R}^n\rightarrow\mathbb{R}^n`,
- the adaptive function
- :math:`\text{sin}_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^n`
- is defined as:
-
- .. math::
- \text{sin}_{\text{adaptive}}({x}) = \alpha\,\text{sin}(\beta{x}+\gamma),
-
- where :math:`\alpha,\,\beta,\,\gamma` are trainable parameters.
-
- .. seealso::
-
- **Original reference**: Godfrey, Luke B., and Michael S. Gashler.
- *A continuum among logarithmic, linear, and exponential functions,
- and its potential to improve generalization in neural networks.*
- 2015 7th international joint conference on knowledge discovery,
- knowledge engineering and knowledge management (IC3K).
- Vol. 1. IEEE, 2015. DOI: `arXiv preprint arXiv:1602.01321.
-