diff --git a/CHANGELOG.md b/CHANGELOG.md index ab5a536d..0021f6b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,11 +10,23 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `UW` (Uncertainty Weighting) from [Multi-Task Learning Using Uncertainty to Weigh Losses + for Scene Geometry and + Semantics](https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf), + a `Scalarizer` that combines the values using learned per-task uncertainties. It is the first + stateful, trainable scalarizer: its log-variances are an `nn.Parameter` that must be passed to + the optimizer. - Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf), a `Scalarizer` that returns the geometric mean of the input tensor of values. +### Changed + +- **BREAKING**: Moved the `Stateful` mixin from `torchjd.aggregation` to the top-level `torchjd` + namespace, so it can be shared between the aggregation and scalarization packages. Import it as + `torchjd.Stateful` instead of `torchjd.aggregation.Stateful`. + ## [0.12.0] - 2026-05-28 ### Added diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 13e405cb..4874b9ea 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -19,7 +19,7 @@ Abstract base classes .. autoclass:: torchjd.aggregation.Weighting :members: __call__ -.. autoclass:: torchjd.aggregation.Stateful +.. 
autoclass:: torchjd.Stateful :members: reset diff --git a/docs/source/docs/scalarization/index.rst b/docs/source/docs/scalarization/index.rst index 8fd87dc8..290587de 100644 --- a/docs/source/docs/scalarization/index.rst +++ b/docs/source/docs/scalarization/index.rst @@ -20,3 +20,4 @@ Abstract base class mean.rst random.rst sum.rst + uw.rst diff --git a/docs/source/docs/scalarization/uw.rst b/docs/source/docs/scalarization/uw.rst new file mode 100644 index 00000000..f797a4c7 --- /dev/null +++ b/docs/source/docs/scalarization/uw.rst @@ -0,0 +1,7 @@ +:hide-toc: + +UW +== + +.. autoclass:: torchjd.scalarization.UW + :members: __call__ diff --git a/docs/source/examples/grouping.rst b/docs/source/examples/grouping.rst index dff8d127..1e5d8293 100644 --- a/docs/source/examples/grouping.rst +++ b/docs/source/examples/grouping.rst @@ -19,7 +19,7 @@ the parameters: In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated -aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each instance +aggregator instance per group. For :class:`~torchjd.Stateful` aggregators, each instance should independently maintain its own state (e.g. the EMA :math:`\hat{\phi}` state in :class:`~torchjd.aggregation.GradVac`, matching the per-block targets from the original paper). 
diff --git a/src/torchjd/__init__.py b/src/torchjd/__init__.py index 4253561a..7e241ced 100644 --- a/src/torchjd/__init__.py +++ b/src/torchjd/__init__.py @@ -1,8 +1,11 @@ from collections.abc import Callable from warnings import warn as _warn +from ._mixins import Stateful from .autojac import backward as _backward, mtl_backward as _mtl_backward +__all__ = ["Stateful"] + _deprecated_items: dict[str, tuple[str, Callable]] = { "backward": ("autojac", _backward), "mtl_backward": ("autojac", _mtl_backward), diff --git a/src/torchjd/_mixins.py b/src/torchjd/_mixins.py index 60b7ee8c..56ed6530 100644 --- a/src/torchjd/_mixins.py +++ b/src/torchjd/_mixins.py @@ -1,7 +1,16 @@ +from abc import ABC, abstractmethod from importlib.util import find_spec from typing import Any +class Stateful(ABC): + """Mixin adding a reset method.""" + + @abstractmethod + def reset(self) -> None: + """Resets the internal state.""" + + class _WithOptionalDeps: """ Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 1814d320..558d828b 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -52,7 +52,6 @@ from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting -from ._mixins import Stateful from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting @@ -92,7 +91,6 @@ "PCGradWeighting", "Random", "RandomWeighting", - "Stateful", "Sum", "SumWeighting", "TrimmedMean", diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index d056c5ec..31146e3e 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -4,7 +4,7 @@ from torch import Tensor -from torchjd.aggregation._mixins import Stateful +from torchjd._mixins import Stateful from 
._weighting_bases import Weighting @@ -13,7 +13,7 @@ class CRMOGMWeighting(Weighting[_T], Stateful): r""" - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd._mixins.Stateful` :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it produces with an exponential moving average (EMA) across calls. This is the weight-smoothing @@ -61,7 +61,7 @@ class CRMOGMWeighting(Weighting[_T], Stateful): This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset` to restart the smoothing from the initial state. Note that calling :meth:`reset` will also - reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`. + reset the wrapped weighting if it is :class:`~torchjd._mixins.Stateful`. :param weighting: The wrapped weighting whose output is smoothed. :param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing @@ -120,7 +120,7 @@ def alpha(self, value: float) -> None: def reset(self) -> None: r""" Clears the EMA state so the next forward restarts from the initial state. Also resets the - wrapped weighting if it is :class:`~torchjd.aggregation._mixins.Stateful`. + wrapped weighting if it is :class:`~torchjd._mixins.Stateful`. """ if isinstance(self.weighting, Stateful): diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 8a031e64..0ad7436d 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -5,7 +5,8 @@ import torch from torch import Tensor -from torchjd.aggregation._mixins import Stateful, _NonDifferentiable +from torchjd._mixins import Stateful +from torchjd.aggregation._mixins import _NonDifferentiable from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator @@ -15,7 +16,7 @@ # Non-differentiable: weights are modified in-place during the gradient correction loop. 
class GradVacWeighting(_GramianWeighting, Stateful, _NonDifferentiable): r""" - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights of :class:`~torchjd.aggregation.GradVac`. @@ -130,7 +131,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None: class GradVac(GramianWeightedAggregator, Stateful, _NonDifferentiable): r""" - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd._mixins.Stateful` :class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index b906140b..8f37b62e 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -1,18 +1,9 @@ -from abc import ABC, abstractmethod from typing import Any import torch from torch import nn -class Stateful(ABC): - """Mixin adding a reset method.""" - - @abstractmethod - def reset(self) -> None: - """Resets the internal state.""" - - class _NonDifferentiable(nn.Module): """ Mixin making a nn.Module non-differentiable, preventing autograd graph construction by wrapping diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 879b4221..ae1c2546 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -8,8 +8,8 @@ import torch from torch import Tensor -from torchjd._mixins import _WithOptionalDeps -from torchjd.aggregation._mixins import Stateful, _NonDifferentiable +from torchjd._mixins import Stateful, _WithOptionalDeps +from torchjd.aggregation._mixins import _NonDifferentiable from ._aggregator_bases import WeightedAggregator from ._weighting_bases import _MatrixWeighting @@ -25,7 +25,7 @@ 
class _NashMTLWeighting(_WithOptionalDeps, _MatrixWeighting, Stateful, _NonDiffe _REQUIRED_DEPS = ["numpy", "cvxpy", "ecos"] _INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"' """ - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining Game <https://arxiv.org/pdf/2202.01017.pdf>`_. @@ -206,7 +206,7 @@ def reset(self) -> None: class NashMTL(WeightedAggregator, Stateful, _NonDifferentiable): """ - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd._mixins.Stateful` :class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of `Multi-Task Learning as a Bargaining Game <https://arxiv.org/pdf/2202.01017.pdf>`_. diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py index 337d38ca..d50cb606 100644 --- a/src/torchjd/scalarization/__init__.py +++ b/src/torchjd/scalarization/__init__.py @@ -25,5 +25,6 @@ from ._random import Random from ._scalarizer_base import Scalarizer from ._sum import Sum +from ._uw import UW -__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "Sum"] +__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "Sum", "UW"] diff --git a/src/torchjd/scalarization/_uw.py b/src/torchjd/scalarization/_uw.py new file mode 100644 index 00000000..6b1f10d0 --- /dev/null +++ b/src/torchjd/scalarization/_uw.py @@ -0,0 +1,80 @@ +from collections.abc import Sequence + +import torch +from torch import Tensor, nn + +from torchjd._mixins import Stateful + +from ._scalarizer_base import Scalarizer + + +class UW(Scalarizer, Stateful): + r""" + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using + learned per-task uncertainties. ``UW`` is short for Uncertainty Weighting, the method proposed + in `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics + <https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf>`_. 
+ + Each value :math:`L_i` is assigned a learnable log-variance :math:`s_i`, and the values are + combined as + + .. math:: + \sum_i \left( \frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i \right) + + where: + + - :math:`L_i` is the :math:`i`-th value (typically the loss of task :math:`i`); + - :math:`s_i = \log \sigma_i^2` is the learnable log-variance of task :math:`i`. + + Following the paper, the log-variance :math:`s_i` is learned rather than the variance + :math:`\sigma_i^2` directly: this is numerically more stable (the combination never divides by + zero) and keeps :math:`s_i` unconstrained, since :math:`e^{-s_i}` is always positive. The + :math:`s_i` are stored as an ``nn.Parameter``, so the parameters of this scalarizer must be + passed to the optimizer to be learned jointly with the model. + + :param shape: The shape of the values to scalarize, used to create one log-variance per value. + An ``int`` ``n`` is interpreted as the shape ``(n,)``. + + The following example shows how to co-train a model together with the per-task log-variances, by + passing both sets of parameters to the optimizer. + + >>> import torch + >>> from torch.nn import Linear + >>> + >>> from torchjd.scalarization import UW + >>> + >>> model = Linear(3, 2) + >>> scalarizer = UW(2) + >>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1) + >>> + >>> features = torch.randn(8, 3) + >>> losses = model(features).pow(2).mean(dim=0) # One loss per output dimension. + >>> loss = scalarizer(losses) + >>> loss.backward() + >>> optimizer.step() + + .. note:: + The log-variances are initialized to ``0`` (i.e. :math:`\sigma_i^2 = 1`), which gives + uniform weights at the start of training. The paper reports that the result is robust to + this initialization. (`LibMTL <https://github.com/median-research-group/LibMTL>`_ + initializes them to ``-0.5`` instead.) 
+ """ + + def __init__(self, shape: int | Sequence[int]) -> None: + super().__init__() + self.log_var = nn.Parameter(torch.zeros(shape)) + + def forward(self, values: Tensor, /) -> Tensor: + if values.shape != self.log_var.shape: + raise ValueError( + f"Parameter `values` should have shape {tuple(self.log_var.shape)} (matching the " + f"shape of the log-variances). Found `values.shape = {tuple(values.shape)}`.", + ) + return (0.5 * torch.exp(-self.log_var) * values + 0.5 * self.log_var).sum() + + def reset(self) -> None: + with torch.no_grad(): + self.log_var.zero_() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(shape={tuple(self.log_var.shape)})" diff --git a/tests/trajectories/optimize.py b/tests/trajectories/optimize.py index 1e78b480..f7cab08d 100644 --- a/tests/trajectories/optimize.py +++ b/tests/trajectories/optimize.py @@ -19,7 +19,7 @@ import torch from tests.paths import TRAJECTORIES_RESULTS_DIR -from torchjd.aggregation import Stateful +from torchjd import Stateful from trajectories._constants import ( AGGREGATORS, BASE_LEARNING_RATES, diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py index b01945b3..fb2da79d 100644 --- a/tests/unit/aggregation/test_cr_mogm.py +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -56,7 +56,7 @@ def test_reset_restores_first_step_behavior() -> None: def test_reset_propagates_to_stateful_weighting() -> None: """ Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is - :class:`~torchjd.aggregation.Stateful`. Checks that ``GradVacWeighting``'s internal + :class:`~torchjd._mixins.Stateful`. Checks that ``GradVacWeighting``'s internal state is cleared after ``reset()``. 
""" diff --git a/tests/unit/scalarization/test_uw.py b/tests/unit/scalarization/test_uw.py new file mode 100644 index 00000000..21bc3961 --- /dev/null +++ b/tests/unit/scalarization/test_uw.py @@ -0,0 +1,100 @@ +from contextlib import nullcontext as does_not_raise + +import torch +from pytest import mark, raises +from settings import DEVICE, DTYPE +from torch import Tensor +from utils.contexts import ExceptionContext +from utils.tensors import ones_, tensor_, zeros_ + +from torchjd.scalarization import UW + +from ._asserts import assert_grad_flow, assert_returns_scalar +from ._inputs import all_inputs + + +def _uw(shape: int | tuple[int, ...]) -> UW: + """Builds a `UW` whose log-variances live on the test device and dtype.""" + return UW(shape).to(device=DEVICE, dtype=DTYPE) + + +def test_value() -> None: + # With log-variances initialized to 0, the result is 0.5 * sum(values). + values = tensor_([1.0, 2.0, 4.0]) + torch.testing.assert_close(_uw((3,))(values), tensor_(3.5)) + + +def test_int_shape_matches_tuple_shape() -> None: + values = tensor_([1.0, 2.0, 4.0]) + assert UW(3).log_var.shape == (3,) + torch.testing.assert_close(_uw(3)(values), _uw((3,))(values)) + + +@mark.parametrize("values", all_inputs) +def test_expected_structure(values: Tensor) -> None: + assert_returns_scalar(_uw(tuple(values.shape)), values) + + +@mark.parametrize("values", all_inputs) +def test_grad_flow(values: Tensor) -> None: + assert_grad_flow(_uw(tuple(values.shape)), values) + + +@mark.parametrize("values", all_inputs) +def test_grad_flows_to_log_var(values: Tensor) -> None: + scalarizer = _uw(tuple(values.shape)) + scalarizer(values).backward() + assert scalarizer.log_var.grad is not None + assert scalarizer.log_var.grad.isfinite().all() + + +@mark.parametrize( + ["param_shape", "values_shape", "expectation"], + [ + ((5,), (5,), does_not_raise()), + ((3, 4), (3, 4), does_not_raise()), + ((), (), does_not_raise()), + ((5,), (4,), raises(ValueError)), + ((5,), (5, 1), 
raises(ValueError)), + ((3, 4), (4, 3), raises(ValueError)), + ], +) +def test_shape_check( + param_shape: tuple[int, ...], + values_shape: tuple[int, ...], + expectation: ExceptionContext, +) -> None: + scalarizer = _uw(param_shape) + values = ones_(values_shape) + with expectation: + _ = scalarizer(values) + + +def test_reset_restores_initial_log_var() -> None: + scalarizer = _uw((3,)) + with torch.no_grad(): + scalarizer.log_var.add_(1.0) + scalarizer.reset() + torch.testing.assert_close(scalarizer.log_var.detach(), zeros_((3,))) + + +def test_does_not_raise_on_negative_input() -> None: + # Unlike GeometricMean, UW has no positivity precondition. + values = tensor_([-1.0, -2.0, 3.0]) + assert_returns_scalar(_uw((3,)), values) + + +def test_is_trainable() -> None: + scalarizer = _uw((2,)) + optimizer = torch.optim.SGD(scalarizer.parameters(), lr=0.1) + values = tensor_([2.0, 5.0]) + optimizer.zero_grad() + scalarizer(values).backward() + optimizer.step() + assert not torch.equal(scalarizer.log_var.detach(), zeros_((2,))) + + +def test_representations() -> None: + assert repr(UW(3)) == "UW(shape=(3,))" + assert repr(UW((2, 3))) == "UW(shape=(2, 3))" + assert str(UW(3)) == "UW"