Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,23 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `UW` (Uncertainty Weighting) from [Multi-Task Learning Using Uncertainty to Weigh Losses
for Scene Geometry and
Semantics](https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf),
a `Scalarizer` that combines the values using learned per-task uncertainties. It is the first
stateful, trainable scalarizer: its log-variances are an `nn.Parameter` that must be passed to
the optimizer.
- Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature
Aggregation and Geometric Loss Strategy for Multi-Task
Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf),
a `Scalarizer` that returns the geometric mean of the input tensor of values.

### Changed

- **BREAKING**: Moved the `Stateful` mixin from `torchjd.aggregation` to the top-level `torchjd`
namespace, so it can be shared between the aggregation and scalarization packages. Import it as
`torchjd.Stateful` instead of `torchjd.aggregation.Stateful`.

## [0.12.0] - 2026-05-28

### Added
Expand Down
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Abstract base classes
.. autoclass:: torchjd.aggregation.Weighting
:members: __call__

.. autoclass:: torchjd.aggregation.Stateful
.. autoclass:: torchjd.Stateful
:members: reset


Expand Down
1 change: 1 addition & 0 deletions docs/source/docs/scalarization/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ Abstract base class
mean.rst
random.rst
sum.rst
uw.rst
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/uw.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

UW
==

.. autoclass:: torchjd.scalarization.UW
:members: __call__
2 changes: 1 addition & 1 deletion docs/source/examples/grouping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ the parameters:

In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated
aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each instance
aggregator instance per group. For :class:`~torchjd.Stateful` aggregators, each instance
should independently maintain its own state (e.g. the EMA :math:`\hat{\phi}` state in
:class:`~torchjd.aggregation.GradVac`, matching the per-block targets from the original paper).

Expand Down
3 changes: 3 additions & 0 deletions src/torchjd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from collections.abc import Callable
from warnings import warn as _warn

from ._mixins import Stateful
from .autojac import backward as _backward, mtl_backward as _mtl_backward

__all__ = ["Stateful"]

_deprecated_items: dict[str, tuple[str, Callable]] = {
"backward": ("autojac", _backward),
"mtl_backward": ("autojac", _mtl_backward),
Expand Down
9 changes: 9 additions & 0 deletions src/torchjd/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from abc import ABC, abstractmethod
from importlib.util import find_spec
from typing import Any


class Stateful(ABC):
    """Mixin for objects that keep an internal state which can be reset."""

    @abstractmethod
    def reset(self) -> None:
        """Reset the internal state. Must be implemented by every concrete subclass."""


class _WithOptionalDeps:
"""
Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies
Expand Down
2 changes: 0 additions & 2 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from ._krum import Krum, KrumWeighting
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._mixins import Stateful
from ._nash_mtl import NashMTL
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
Expand Down Expand Up @@ -92,7 +91,6 @@
"PCGradWeighting",
"Random",
"RandomWeighting",
"Stateful",
"Sum",
"SumWeighting",
"TrimmedMean",
Expand Down
8 changes: 4 additions & 4 deletions src/torchjd/aggregation/_cr_mogm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torch import Tensor

from torchjd.aggregation._mixins import Stateful
from torchjd._mixins import Stateful

from ._weighting_bases import Weighting

Expand All @@ -13,7 +13,7 @@

class CRMOGMWeighting(Weighting[_T], Stateful):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd._mixins.Stateful`
:class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another
:class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it
produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
Expand Down Expand Up @@ -61,7 +61,7 @@ class CRMOGMWeighting(Weighting[_T], Stateful):

This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
to restart the smoothing from the initial state. Note that calling :meth:`reset` will also
reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`.
reset the wrapped weighting if it is :class:`~torchjd._mixins.Stateful`.

:param weighting: The wrapped weighting whose output is smoothed.
:param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing
Expand Down Expand Up @@ -120,7 +120,7 @@ def alpha(self, value: float) -> None:
def reset(self) -> None:
r"""
Clears the EMA state so the next forward restarts from the initial state. Also resets the
wrapped weighting if it is :class:`~torchjd.aggregation._mixins.Stateful`.
wrapped weighting if it is :class:`~torchjd._mixins.Stateful`.
"""

if isinstance(self.weighting, Stateful):
Expand Down
7 changes: 4 additions & 3 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
from torchjd._mixins import Stateful
from torchjd.aggregation._mixins import _NonDifferentiable
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
Expand All @@ -15,7 +16,7 @@
# Non-differentiable: weights are modified in-place during the gradient correction loop.
class GradVacWeighting(_GramianWeighting, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.GradVac`.

Expand Down Expand Up @@ -130,7 +131,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:

class GradVac(GramianWeightedAggregator, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd._mixins.Stateful`
:class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
Expand Down
9 changes: 0 additions & 9 deletions src/torchjd/aggregation/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any

import torch
from torch import nn


class Stateful(ABC):
"""Mixin adding a reset method."""

@abstractmethod
def reset(self) -> None:
"""Resets the internal state."""


class _NonDifferentiable(nn.Module):
"""
Mixin making a nn.Module non-differentiable, preventing autograd graph construction by wrapping
Expand Down
8 changes: 4 additions & 4 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import torch
from torch import Tensor

from torchjd._mixins import _WithOptionalDeps
from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
from torchjd._mixins import Stateful, _WithOptionalDeps
from torchjd.aggregation._mixins import _NonDifferentiable

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import _MatrixWeighting
Expand All @@ -25,7 +25,7 @@ class _NashMTLWeighting(_WithOptionalDeps, _MatrixWeighting, Stateful, _NonDiffe
_REQUIRED_DEPS = ["numpy", "cvxpy", "ecos"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"'
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that
extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining
Game <https://arxiv.org/pdf/2202.01017.pdf>`_.
Expand Down Expand Up @@ -206,7 +206,7 @@ def reset(self) -> None:

class NashMTL(WeightedAggregator, Stateful, _NonDifferentiable):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd._mixins.Stateful`
:class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of
`Multi-Task Learning as a Bargaining Game <https://arxiv.org/pdf/2202.01017.pdf>`_.

Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/scalarization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@
from ._random import Random
from ._scalarizer_base import Scalarizer
from ._sum import Sum
from ._uw import UW

__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "Sum"]
__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "Sum", "UW"]
80 changes: 80 additions & 0 deletions src/torchjd/scalarization/_uw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections.abc import Sequence

import torch
from torch import Tensor, nn

from torchjd._mixins import Stateful

from ._scalarizer_base import Scalarizer


class UW(Scalarizer, Stateful):
    r"""
    Uncertainty Weighting (``UW``): a :class:`~torchjd.scalarization.Scalarizer` that reduces the
    input tensor of values to a scalar by weighting them with learned per-task uncertainties. The
    method comes from `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and
    Semantics
    <https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf>`_.

    A learnable log-variance :math:`s_i` is attached to every value :math:`L_i`, and the output is

    .. math::
        \sum_i \left( \frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i \right)

    where:

    - :math:`L_i` is the :math:`i`-th value (typically the loss of task :math:`i`);
    - :math:`s_i = \log \sigma_i^2` is the learnable log-variance of task :math:`i`.

    As in the paper, it is the log-variance :math:`s_i` that is learned, not the variance
    :math:`\sigma_i^2` itself. This is numerically more stable (no division by zero can occur in
    the combination) and leaves :math:`s_i` unconstrained, because :math:`e^{-s_i}` is positive
    for any :math:`s_i`. The :math:`s_i` live in an ``nn.Parameter``, so the parameters of this
    scalarizer have to be given to the optimizer in order to be learned together with the model.

    :param shape: The shape of the values to scalarize, used to create one log-variance per value.
        An ``int`` ``n`` is interpreted as the shape ``(n,)``.

    The example below co-trains a model and the per-task log-variances by handing both sets of
    parameters to the optimizer.

        >>> import torch
        >>> from torch.nn import Linear
        >>>
        >>> from torchjd.scalarization import UW
        >>>
        >>> model = Linear(3, 2)
        >>> scalarizer = UW(2)
        >>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1)
        >>>
        >>> features = torch.randn(8, 3)
        >>> losses = model(features).pow(2).mean(dim=0)  # One loss per output dimension.
        >>> loss = scalarizer(losses)
        >>> loss.backward()
        >>> optimizer.step()

    .. note::
        The log-variances start at ``0`` (i.e. :math:`\sigma_i^2 = 1`), so all values are weighted
        uniformly at the beginning of training. The paper reports that the result is robust to
        this initialization. (`LibMTL <https://github.com/median-research-group/LibMTL>`_
        initializes them to ``-0.5`` instead.)
    """

    def __init__(self, shape: int | Sequence[int]) -> None:
        super().__init__()
        # One log-variance per value, all starting at 0.
        initial_log_var = torch.zeros(shape)
        self.log_var = nn.Parameter(initial_log_var)

    def forward(self, values: Tensor, /) -> Tensor:
        """
        Combine ``values`` into a scalar using the current log-variances.

        :param values: The values to scalarize; their shape must equal that of the log-variances.
        :raises ValueError: If the shape of ``values`` differs from that of the log-variances.
        """

        expected_shape = tuple(self.log_var.shape)
        found_shape = tuple(values.shape)
        if found_shape != expected_shape:
            raise ValueError(
                f"Parameter `values` should have shape {expected_shape} (matching the "
                f"shape of the log-variances). Found `values.shape = {found_shape}`.",
            )

        # Data term: each value scaled by half of its precision exp(-s_i).
        weighted_values = 0.5 * torch.exp(-self.log_var) * values
        # Regularization term: penalizes large log-variances.
        regularization = 0.5 * self.log_var
        return torch.sum(weighted_values + regularization)

    def reset(self) -> None:
        """Set the log-variances back to ``0`` in place, without recording the op in autograd."""

        # nn.init.zeros_ zeroes the tensor in place under torch.no_grad().
        nn.init.zeros_(self.log_var)

    def __repr__(self) -> str:
        shape = tuple(self.log_var.shape)
        return f"{type(self).__name__}(shape={shape})"
2 changes: 1 addition & 1 deletion tests/trajectories/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch

from tests.paths import TRAJECTORIES_RESULTS_DIR
from torchjd.aggregation import Stateful
from torchjd import Stateful
from trajectories._constants import (
AGGREGATORS,
BASE_LEARNING_RATES,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_cr_mogm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_reset_restores_first_step_behavior() -> None:
def test_reset_propagates_to_stateful_weighting() -> None:
"""
Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is
:class:`~torchjd.aggregation.Stateful`. Checks that ``GradVacWeighting``'s internal
:class:`~torchjd._mixins.Stateful`. Checks that ``GradVacWeighting``'s internal
state is cleared after ``reset()``.
"""

Expand Down
Loading
Loading