From b981d098ef3828e94880f8e10233921a810e45f5 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 28 May 2026 23:47:35 -0400 Subject: [PATCH 1/5] feat(aggregation): Add MoDoWeighting --- CHANGELOG.md | 6 + docs/source/docs/aggregation/index.rst | 1 + docs/source/docs/aggregation/modo.rst | 7 ++ src/torchjd/aggregation/__init__.py | 2 + src/torchjd/aggregation/_modo.py | 143 +++++++++++++++++++++++ tests/unit/aggregation/test_modo.py | 153 +++++++++++++++++++++++++ 6 files changed, 312 insertions(+) create mode 100644 docs/source/docs/aggregation/modo.rst create mode 100644 src/torchjd/aggregation/_modo.py create mode 100644 tests/unit/aggregation/test_modo.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9204433a..7ca370c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization,Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a + softmax-projected gradient step on the Gramian, intended to be composed with `autogram.Engine` + in a two-batch training loop. + ## [0.12.0] - 2026-05-28 ### Added diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 13e405cb..66d74570 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -41,6 +41,7 @@ Abstract base classes krum.rst mean.rst mgda.rst + modo.rst nash_mtl.rst pcgrad.rst random.rst diff --git a/docs/source/docs/aggregation/modo.rst b/docs/source/docs/aggregation/modo.rst new file mode 100644 index 00000000..98b8d515 --- /dev/null +++ b/docs/source/docs/aggregation/modo.rst @@ -0,0 +1,7 @@ +:hide-toc: + +MoDo +==== + +.. autoclass:: torchjd.aggregation.MoDoWeighting + :members: __call__, reset diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 1814d320..92bbadec 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -53,6 +53,7 @@ from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting from ._mixins import Stateful +from ._modo import MoDoWeighting from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting @@ -87,6 +88,7 @@ "MeanWeighting", "MGDA", "MGDAWeighting", + "MoDoWeighting", "NashMTL", "PCGrad", "PCGradWeighting", diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py new file mode 100644 index 00000000..d24629b3 --- /dev/null +++ b/src/torchjd/aggregation/_modo.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd.aggregation._mixins import Stateful +from torchjd.linalg import PSDMatrix + +from ._weighting_bases import _GramianWeighting + + +class MoDoWeighting(_GramianWeighting, Stateful): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] implementing the + task-weight update from `Three-Way Trade-Off in Multi-Objective Learning: Optimization, + Generalization and Conflict-Avoidance `_ + (JMLR 2024), commonly referred to as MoDo (Multi-Objective gradient with Double sampling). + + At each call, the weights :math:`\lambda` are updated by a projected gradient step on + :math:`\lambda^\top G \lambda + \rho \|\lambda\|^2` where :math:`G = G_1 G_1^\top` is the + Gramian of the first mini-batch's Jacobian: + + .. math:: + + \lambda_{t+1} = \operatorname{softmax}\!\bigl( + \lambda_t - \gamma \cdot (G \lambda_t + \rho \lambda_t) + \bigr) + + The paper specifies hard simplex projection :math:`\Pi_\Delta`; we follow the `official + LibMTL implementation `_ and use + :func:`torch.softmax` as the projection step. + + The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector + :math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset + automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use + :meth:`reset` to manually restart the smoothing from uniform weights. + + .. warning:: + MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this + weighting must come from a mini-batch that is independent of the one used for the + subsequent parameter update. See the usage example below. + + :param gamma: Learning rate of the task-weight update. Must be positive. + :param rho: Non-negative :math:`\ell_2` regularisation coefficient. + + .. admonition:: Example + + Train a model using MoDo with two independent mini-batches per step. The first batch + drives the :math:`\lambda` update via the Gramian; the second batch drives the parameter + update via the usual backward pass. + + .. code-block:: python + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autogram import Engine + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + engine = Engine(model, batch_dim=0) + + # loader_1 and loader_2 must yield independent draws from the same distribution. + for batch_1, batch_2 in zip(loader_1, loader_2): + input_1, target_1 = batch_1 + input_2, target_2 = batch_2 + + # Step 1: Gramian from batch 1 drives the lambda update. + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + gramian = engine.compute_gramian(losses_1) + weights = weighting(gramian) + + # Step 2: backward on batch 2 with those weights drives the parameter update. + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + losses_2.backward(weights) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None: + super().__init__() + self.gamma = gamma + self.rho = rho + self._lambda: Tensor | None = None + self._state_key: tuple[int, torch.dtype, torch.device] | None = None + + @property + def gamma(self) -> float: + return self._gamma + + @gamma.setter + def gamma(self, value: float) -> None: + if value <= 0.0: + raise ValueError(f"Attribute `gamma` must be positive. Found gamma={value!r}.") + self._gamma = value + + @property + def rho(self) -> float: + return self._rho + + @rho.setter + def rho(self, value: float) -> None: + if value < 0.0: + raise ValueError(f"Attribute `rho` must be non-negative. Found rho={value!r}.") + self._rho = value + + def reset(self) -> None: + """Clears the stored task weights so the next forward starts from uniform.""" + + self._lambda = None + self._state_key = None + + def forward(self, gramian: PSDMatrix, /) -> Tensor: + m = gramian.shape[0] + if m == 0: + return gramian.new_empty((0,)) + + self._ensure_state(gramian) + lambd = cast(Tensor, self._lambda) + + with torch.no_grad(): + grad = gramian @ lambd + self._rho * lambd + lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) + + self._lambda = lambd + return lambd + + def _ensure_state(self, gramian: PSDMatrix) -> None: + key = (gramian.shape[0], gramian.dtype, gramian.device) + if self._state_key == key and self._lambda is not None: + return + self._lambda = gramian.new_full((gramian.shape[0],), 1.0 / gramian.shape[0]) + self._state_key = key + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(gamma={self.gamma!r}, rho={self.rho!r})" diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py new file mode 100644 index 00000000..af80fbab --- /dev/null +++ b/tests/unit/aggregation/test_modo.py @@ -0,0 +1,153 @@ +import torch +from pytest import mark, raises +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator +from torchjd.aggregation._modo import MoDoWeighting + +from ._asserts import assert_expected_structure +from ._inputs import scaled_matrices, typical_matrices + +gramian_pairs = [ + (GramianWeightedAggregator(MoDoWeighting()), m) for m in typical_matrices + scaled_matrices +] + + +def test_representations() -> None: + W = MoDoWeighting(gamma=0.1, rho=0.05) + assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)" + + +@mark.parametrize(["aggregator", "matrix"], gramian_pairs) +def test_expected_structure_gramian_weighting( + aggregator: GramianWeightedAggregator, matrix: Tensor +) -> None: + assert_expected_structure(aggregator, matrix) + + +def test_reset_restores_first_step_behavior() -> None: + J = randn_((3, 8)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + +def test_gamma_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.gamma = 0.01 + assert W.gamma == 0.01 + W.gamma = 0.1 + assert W.gamma == 0.1 + W.gamma = 1.0 + assert W.gamma == 1.0 + + +def test_gamma_setter_rejects_non_positive() -> None: + W = MoDoWeighting() + with raises(ValueError, match="gamma"): + W.gamma = 0.0 + with raises(ValueError, match="gamma"): + W.gamma = -0.1 + + +def test_rho_setter_accepts_valid() -> None: + W = MoDoWeighting() + W.rho = 0.0 + assert W.rho == 0.0 + W.rho = 0.1 + assert W.rho == 0.1 + + +def test_rho_setter_rejects_negative() -> None: + W = MoDoWeighting() + with raises(ValueError, match="rho"): + W.rho = -0.1 + + +def test_output_lies_on_simplex() -> None: + """The softmax projection ensures the weights sum to 1 and are non-negative.""" + + J = randn_((4, 10)) + G = J @ J.T + W = MoDoWeighting(gamma=0.1, rho=0.05) + weights = W(G) + assert weights.shape == (4,) + assert (weights >= 0).all() + assert_close(weights.sum(), tensor_(1.0)) + + +def test_small_gamma_stays_near_uniform() -> None: + """With a tiny gamma, one step barely moves lambda from the uniform initialisation.""" + + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + W = MoDoWeighting(gamma=1e-8) + uniform = tensor_([1.0 / m] * m) + assert_close(W(G), uniform, atol=1e-6, rtol=1e-6) + + +def test_update_recurrence() -> None: + """Verify one step of the softmax-projected gradient update by hand.""" + + gamma = 0.1 + rho = 0.05 + J = randn_((3, 8)) + G = J @ J.T + m = J.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + lambda_0 = tensor_([1.0 / m] * m) + grad = G @ lambda_0 + rho * lambda_0 + expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + + assert_close(W(G), expected) + + +def test_two_consecutive_steps() -> None: + """Verify two consecutive steps of the softmax-projected gradient update.""" + + gamma = 0.1 + rho = 0.0 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G1 = J1 @ J1.T + G2 = J2 @ J2.T + m = J1.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + + lambda_0 = tensor_([1.0 / m] * m) + grad_1 = G1 @ lambda_0 + rho * lambda_0 + lambda_1 = torch.softmax(lambda_0 - gamma * grad_1, dim=-1) + + grad_2 = G2 @ lambda_1 + rho * lambda_1 + lambda_2 = torch.softmax(lambda_1 - gamma * grad_2, dim=-1) + + assert_close(W(G1), lambda_1) + assert_close(W(G2), lambda_2) + + +def test_changing_m_auto_resets() -> None: + """When the number of objectives changes, the state is re-initialised to uniform.""" + + W = MoDoWeighting(gamma=0.1) + W(randn_((3, 8)) @ randn_((3, 8)).T) + # After a state-resetting call with m=2, the first output should equal the uniform step's output. + fresh = MoDoWeighting(gamma=0.1) + J = randn_((2, 8)) + G = J @ J.T + assert_close(W(G), fresh(G)) + + +def test_zero_rows() -> None: + """A (0, 0) Gramian yields an empty weight vector.""" + + W = MoDoWeighting() + weights = W(tensor_([]).reshape(0, 0)) + assert weights.shape == (0,) From b416fbadffeb90f33761c32faa0598a4309e5731 Mon Sep 17 00:00:00 2001 From: Khush Date: Fri, 29 May 2026 10:29:45 -0400 Subject: [PATCH 2/5] refactor(aggregation): Address review feedback on MoDoWeighting --- src/torchjd/aggregation/_modo.py | 31 ++++++++++++----------------- tests/unit/aggregation/test_modo.py | 8 -------- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index d24629b3..d219b44e 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -5,23 +5,22 @@ import torch from torch import Tensor -from torchjd.aggregation._mixins import Stateful +from torchjd.aggregation._mixins import Stateful, _NonDifferentiable from torchjd.linalg import PSDMatrix from ._weighting_bases import _GramianWeighting -class MoDoWeighting(_GramianWeighting, Stateful): +class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): r""" :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] implementing the - task-weight update from `Three-Way Trade-Off in Multi-Objective Learning: Optimization, - Generalization and Conflict-Avoidance `_ - (JMLR 2024), commonly referred to as MoDo (Multi-Objective gradient with Double sampling). + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] from `Three-Way + Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance + `_ (JMLR 2024), commonly referred + to as MoDo (Multi-Objective gradient with Double sampling). - At each call, the weights :math:`\lambda` are updated by a projected gradient step on - :math:`\lambda^\top G \lambda + \rho \|\lambda\|^2` where :math:`G = G_1 G_1^\top` is the - Gramian of the first mini-batch's Jacobian: + Given a Gramian :math:`G`, the weights :math:`\lambda` are updated at each call by a + softmax-projected gradient step: .. math:: @@ -36,12 +35,13 @@ class MoDoWeighting(_GramianWeighting, Stateful): The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector :math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use - :meth:`reset` to manually restart the smoothing from uniform weights. + :meth:`reset` to manually restart from uniform weights. .. warning:: MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this weighting must come from a mini-batch that is independent of the one used for the - subsequent parameter update. See the usage example below. + subsequent parameter update. The Gramian can be computed efficiently from a batch of + losses using the :class:`~torchjd.autogram.Engine`. See the usage example below. :param gamma: Learning rate of the task-weight update. Must be positive. :param rho: Non-negative :math:`\ell_2` regularisation coefficient. @@ -118,16 +118,11 @@ def reset(self) -> None: self._state_key = None def forward(self, gramian: PSDMatrix, /) -> Tensor: - m = gramian.shape[0] - if m == 0: - return gramian.new_empty((0,)) - self._ensure_state(gramian) lambd = cast(Tensor, self._lambda) - with torch.no_grad(): - grad = gramian @ lambd + self._rho * lambd - lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) + grad = gramian @ lambd + self._rho * lambd + lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) self._lambda = lambd return lambd diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py index af80fbab..9b9193be 100644 --- a/tests/unit/aggregation/test_modo.py +++ b/tests/unit/aggregation/test_modo.py @@ -143,11 +143,3 @@ def test_changing_m_auto_resets() -> None: J = randn_((2, 8)) G = J @ J.T assert_close(W(G), fresh(G)) - - -def test_zero_rows() -> None: - """A (0, 0) Gramian yields an empty weight vector.""" - - W = MoDoWeighting() - weights = W(tensor_([]).reshape(0, 0)) - assert weights.shape == (0,) From 2c6188a8dadd28d17cb988bcf19afe7805165232 Mon Sep 17 00:00:00 2001 From: Khush Date: Sun, 31 May 2026 18:13:08 -0400 Subject: [PATCH 3/5] refactor(aggregation): Address review feedback on MoDoWeighting --- CHANGELOG.md | 4 +- src/torchjd/aggregation/_modo.py | 114 +++++++++++++++++----------- tests/unit/aggregation/test_modo.py | 48 +++++++----- 3 files changed, 101 insertions(+), 65 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 225e65d5..926e641b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,9 +10,7 @@ changelog does not include internal changes that do not affect the user. ### Added -- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization,Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a - softmax-projected gradient step on the Gramian, intended to be composed with `autogram.Engine` - in a two-batch training loop. +- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a softmax-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`. - Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf), diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index d219b44e..f38e2ad9 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -6,51 +6,29 @@ from torch import Tensor from torchjd.aggregation._mixins import Stateful, _NonDifferentiable -from torchjd.linalg import PSDMatrix +from torchjd.linalg import Matrix -from ._weighting_bases import _GramianWeighting +from ._weighting_bases import _MatrixWeighting -class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): +class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): r""" :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] from `Three-Way + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance - `_ (JMLR 2024), commonly referred - to as MoDo (Multi-Objective gradient with Double sampling). - - Given a Gramian :math:`G`, the weights :math:`\lambda` are updated at each call by a - softmax-projected gradient step: - - .. math:: - - \lambda_{t+1} = \operatorname{softmax}\!\bigl( - \lambda_t - \gamma \cdot (G \lambda_t + \rho \lambda_t) - \bigr) - - The paper specifies hard simplex projection :math:`\Pi_\Delta`; we follow the `official - LibMTL implementation `_ and use - :func:`torch.softmax` as the projection step. - - The state :math:`\lambda_{t-1}` is initialised lazily to the uniform vector - :math:`[1/m, \ldots, 1/m]` on the first forward call once :math:`m` is known, and is reset - automatically when :math:`m`, ``dtype`` or ``device`` of the input Gramian changes. Use - :meth:`reset` to manually restart from uniform weights. + `_ (JMLR 2024). .. warning:: - MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this - weighting must come from a mini-batch that is independent of the one used for the - subsequent parameter update. The Gramian can be computed efficiently from a batch of - losses using the :class:`~torchjd.autogram.Engine`. See the usage example below. + The input matrix must be :math:`G = J_1 J_2^\top`, computed from two **independent** + mini-batches via :func:`torchjd.autojac.jac`. Using a single-batch Gramian + (:math:`J_1 J_1^\top`) breaks the convergence guarantee. See the usage examples below. :param gamma: Learning rate of the task-weight update. Must be positive. :param rho: Non-negative :math:`\ell_2` regularisation coefficient. - .. admonition:: Example + .. admonition:: Example (two batches per step) - Train a model using MoDo with two independent mini-batches per step. The first batch - drives the :math:`\lambda` update via the Gramian; the second batch drives the parameter - update via the usual backward pass. + The following example reproduces basic MoDo using two independent mini-batches per step. .. code-block:: python @@ -59,29 +37,75 @@ class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): from torch.optim import SGD from torchjd.aggregation import MoDoWeighting - from torchjd.autogram import Engine + from torchjd.autojac import jac model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) optimizer = SGD(model.parameters()) criterion = MSELoss(reduction="none") weighting = MoDoWeighting(gamma=0.1, rho=0.0) - engine = Engine(model, batch_dim=0) + params = list(model.parameters()) - # loader_1 and loader_2 must yield independent draws from the same distribution. + # loader_1 and loader_2 must yield independent draws of the same size. for batch_1, batch_2 in zip(loader_1, loader_2): input_1, target_1 = batch_1 input_2, target_2 = batch_2 - # Step 1: Gramian from batch 1 drives the lambda update. losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) - gramian = engine.compute_gramian(losses_1) - weights = weighting(gramian) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) - # Step 2: backward on batch 2 with those weights drives the parameter update. + # retain_graph=True keeps the graph for the backward step below. losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params, retain_graph=True) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + G = J_1 @ J_2.T + weights = weighting(G) + losses_2.backward(weights) optimizer.step() optimizer.zero_grad() + + .. admonition:: Example (three batches per step) + + The following example reproduces basic MoDo using three independent mini-batches per step, + keeping the :math:`\lambda` update and the parameter update on separate draws. + + .. code-block:: python + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autojac import jac + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + params = list(model.parameters()) + + for batch_1, batch_2, batch_3 in zip(loader_1, loader_2, loader_3): + input_1, target_1 = batch_1 + input_2, target_2 = batch_2 + input_3, target_3 = batch_3 + + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) + + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + G = J_1 @ J_2.T + weights = weighting(G) + + losses_3 = criterion(model(input_3).squeeze(dim=1), target_3) + losses_3.backward(weights) + optimizer.step() + optimizer.zero_grad() """ def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None: @@ -117,21 +141,21 @@ def reset(self) -> None: self._lambda = None self._state_key = None - def forward(self, gramian: PSDMatrix, /) -> Tensor: - self._ensure_state(gramian) + def forward(self, matrix: Matrix, /) -> Tensor: + self._ensure_state(matrix) lambd = cast(Tensor, self._lambda) - grad = gramian @ lambd + self._rho * lambd + grad = matrix @ lambd + self._rho * lambd lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) self._lambda = lambd return lambd - def _ensure_state(self, gramian: PSDMatrix) -> None: - key = (gramian.shape[0], gramian.dtype, gramian.device) + def _ensure_state(self, matrix: Matrix) -> None: + key = (matrix.shape[0], matrix.dtype, matrix.device) if self._state_key == key and self._lambda is not None: return - self._lambda = gramian.new_full((gramian.shape[0],), 1.0 / gramian.shape[0]) + self._lambda = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0]) self._state_key = key def __repr__(self) -> str: diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py index 9b9193be..6203c99d 100644 --- a/tests/unit/aggregation/test_modo.py +++ b/tests/unit/aggregation/test_modo.py @@ -1,32 +1,16 @@ import torch -from pytest import mark, raises -from torch import Tensor +from pytest import raises from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator from torchjd.aggregation._modo import MoDoWeighting -from ._asserts import assert_expected_structure -from ._inputs import scaled_matrices, typical_matrices - -gramian_pairs = [ - (GramianWeightedAggregator(MoDoWeighting()), m) for m in typical_matrices + scaled_matrices -] - def test_representations() -> None: W = MoDoWeighting(gamma=0.1, rho=0.05) assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)" -@mark.parametrize(["aggregator", "matrix"], gramian_pairs) -def test_expected_structure_gramian_weighting( - aggregator: GramianWeightedAggregator, matrix: Tensor -) -> None: - assert_expected_structure(aggregator, matrix) - - def test_reset_restores_first_step_behavior() -> None: J = randn_((3, 8)) G = J @ J.T @@ -143,3 +127,33 @@ def test_changing_m_auto_resets() -> None: J = randn_((2, 8)) G = J @ J.T assert_close(W(G), fresh(G)) + + +def test_non_differentiable() -> None: + """The _NonDifferentiable mixin must prevent autograd graph construction.""" + + G = randn_((3, 8)) @ randn_((3, 8)).T + G.requires_grad_(True) + W = MoDoWeighting() + weights = W(G) + assert not weights.requires_grad + + +def test_non_symmetric_input() -> None: + """MoDoWeighting must accept and correctly process a non-symmetric cross-batch matrix.""" + + gamma = 0.1 + rho = 0.05 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G = J1 @ J2.T # not symmetric, not PSD in general + m = J1.shape[0] + + W = MoDoWeighting(gamma=gamma, rho=rho) + lambda_0 = tensor_([1.0 / m] * m) + grad = G @ lambda_0 + rho * lambda_0 + expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + + assert_close(W(G), expected) + assert W(G).shape == (m,) + assert (W(G) >= 0).all() From 63cf75373e17e5a2b04b1e4966a963029ba0e839 Mon Sep 17 00:00:00 2001 From: Khush Date: Tue, 2 Jun 2026 19:58:36 -0400 Subject: [PATCH 4/5] refactor(aggregation): Address review feedback on MoDoWeighting --- src/torchjd/aggregation/_modo.py | 43 +++++++++++++++++++---------- tests/unit/aggregation/test_modo.py | 39 ++++++++++++-------------- 2 files changed, 45 insertions(+), 37 deletions(-) diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index f38e2ad9..18d3dff2 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -29,8 +29,10 @@ class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): .. admonition:: Example (two batches per step) The following example reproduces basic MoDo using two independent mini-batches per step. + This matches MoDo as described in the paper, and the behavior of the official + implementation when ``three_grads`` is ``False``. - .. code-block:: python + .. testcode:: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -39,22 +41,26 @@ class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): from torchjd.aggregation import MoDoWeighting from torchjd.autojac import jac + # Generate data (8 batches of 16 examples of dim 5) for the sake of the example. + inputs = torch.randn(8, 16, 5) + targets = torch.randn(8, 16) + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) optimizer = SGD(model.parameters()) criterion = MSELoss(reduction="none") weighting = MoDoWeighting(gamma=0.1, rho=0.0) params = list(model.parameters()) - # loader_1 and loader_2 must yield independent draws of the same size. - for batch_1, batch_2 in zip(loader_1, loader_2): - input_1, target_1 = batch_1 - input_2, target_2 = batch_2 + # Consume two consecutive (independent) batches per step. + for i in range(len(inputs) // 2): + input_1, input_2 = inputs[2 * i], inputs[2 * i + 1] + target_1, target_2 = targets[2 * i], targets[2 * i + 1] + # retain_graph=True so both graphs survive for the backward step below. losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) - jacs_1 = jac(losses_1, params) + jacs_1 = jac(losses_1, params, retain_graph=True) J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) - # retain_graph=True keeps the graph for the backward step below. losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) jacs_2 = jac(losses_2, params, retain_graph=True) J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) @@ -62,16 +68,19 @@ class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): G = J_1 @ J_2.T weights = weighting(G) - losses_2.backward(weights) + # Equation 2.9b: the parameter update uses the mean of both batches' losses. + losses = (losses_1 + losses_2) / 2.0 + losses.backward(weights) optimizer.step() optimizer.zero_grad() .. admonition:: Example (three batches per step) The following example reproduces basic MoDo using three independent mini-batches per step, - keeping the :math:`\lambda` update and the parameter update on separate draws. + keeping the :math:`\lambda` update and the parameter update on separate draws. This matches + the behavior of LibMTL and of the official implementation when ``three_grads`` is ``True``. - .. code-block:: python + .. testcode:: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -80,16 +89,20 @@ class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): from torchjd.aggregation import MoDoWeighting from torchjd.autojac import jac + # Generate data (9 batches of 16 examples of dim 5) for the sake of the example. + inputs = torch.randn(9, 16, 5) + targets = torch.randn(9, 16) + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) optimizer = SGD(model.parameters()) criterion = MSELoss(reduction="none") weighting = MoDoWeighting(gamma=0.1, rho=0.0) params = list(model.parameters()) - for batch_1, batch_2, batch_3 in zip(loader_1, loader_2, loader_3): - input_1, target_1 = batch_1 - input_2, target_2 = batch_2 - input_3, target_3 = batch_3 + # Consume three consecutive (independent) batches per step. + for i in range(len(inputs) // 3): + input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2] + target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2] losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) jacs_1 = jac(losses_1, params) @@ -108,7 +121,7 @@ class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): optimizer.zero_grad() """ - def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None: + def __init__(self, gamma: float = 0.1, rho: float = 0.1) -> None: super().__init__() self.gamma = gamma self.rho = rho diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py index 6203c99d..654ac04f 100644 --- a/tests/unit/aggregation/test_modo.py +++ b/tests/unit/aggregation/test_modo.py @@ -12,8 +12,9 @@ def test_representations() -> None: def test_reset_restores_first_step_behavior() -> None: - J = randn_((3, 8)) - G = J @ J.T + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G = J1 @ J2.T W = MoDoWeighting(gamma=0.1) first = W(G) W(G) @@ -56,8 +57,9 @@ def test_rho_setter_rejects_negative() -> None: def test_output_lies_on_simplex() -> None: """The softmax projection ensures the weights sum to 1 and are non-negative.""" - J = randn_((4, 10)) - G = J @ J.T + J1 = randn_((4, 10)) + J2 = randn_((4, 10)) + G = J1 @ J2.T W = MoDoWeighting(gamma=0.1, rho=0.05) weights = W(G) assert weights.shape == (4,) @@ -65,25 +67,15 @@ def test_output_lies_on_simplex() -> None: assert_close(weights.sum(), tensor_(1.0)) -def test_small_gamma_stays_near_uniform() -> None: - """With a tiny gamma, one step barely moves lambda from the uniform initialisation.""" - - J = randn_((3, 8)) - G = J @ J.T - m = J.shape[0] - W = MoDoWeighting(gamma=1e-8) - uniform = tensor_([1.0 / m] * m) - assert_close(W(G), uniform, atol=1e-6, rtol=1e-6) - - def test_update_recurrence() -> None: """Verify one step of the softmax-projected gradient update by hand.""" gamma = 0.1 rho = 0.05 - J = randn_((3, 8)) - G = J @ J.T - m = J.shape[0] + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G = J1 @ J2.T + m = J1.shape[0] W = MoDoWeighting(gamma=gamma, rho=rho) lambda_0 = tensor_([1.0 / m] * m) @@ -100,8 +92,10 @@ def test_two_consecutive_steps() -> None: rho = 0.0 J1 = randn_((3, 8)) J2 = randn_((3, 8)) - G1 = J1 @ J1.T - G2 = J2 @ J2.T + J3 = randn_((3, 8)) + J4 = randn_((3, 8)) + G1 = J1 @ J2.T + G2 = J3 @ J4.T m = J1.shape[0] W = MoDoWeighting(gamma=gamma, rho=rho) @@ -124,8 +118,9 @@ def test_changing_m_auto_resets() -> None: W(randn_((3, 8)) @ randn_((3, 8)).T) # After a state-resetting call with m=2, the first output should equal the uniform step's output. fresh = MoDoWeighting(gamma=0.1) - J = randn_((2, 8)) - G = J @ J.T + J1 = randn_((2, 8)) + J2 = randn_((2, 8)) + G = J1 @ J2.T assert_close(W(G), fresh(G)) From 494553da54ae2fbcb764e89814d93587ff414428 Mon Sep 17 00:00:00 2001 From: Khush Date: Wed, 3 Jun 2026 01:12:03 -0400 Subject: [PATCH 5/5] refactor(aggregation): Use simplex projection from official MoDo implementation --- CHANGELOG.md | 2 +- NOTICES | 28 ++++++++++++++++++++ src/torchjd/aggregation/_modo.py | 24 ++++++++++++++++- tests/unit/aggregation/test_modo.py | 41 ++++++++++++++++++++++++++--- 4 files changed, 89 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 926e641b..7bd9921e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ changelog does not include internal changes that do not affect the user. ### Added -- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a softmax-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`. +- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a simplex-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`. - Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf), diff --git a/NOTICES b/NOTICES index 18a1d601..07c3e851 100644 --- a/NOTICES +++ b/NOTICES @@ -112,3 +112,31 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +------------------------------------------------------------------------------- + +Project: MoDo +Source: https://github.com/heshandevaka/Trade-Off-MOL/blob/main/LibMTL/LibMTL/weighting/MoDo.py +Used in: src/torchjd/aggregation/_modo.py + +MIT License + +Copyright (c) 2023 Heshan Fernando + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index 18d3dff2..7c1a2e83 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -1,3 +1,5 @@ +# Partly adapted from https://github.com/heshandevaka/Trade-Off-MOL — MIT License, Copyright (c) 2023 Heshan Fernando. +# See NOTICES for the full license text. from __future__ import annotations from typing import cast @@ -26,6 +28,10 @@ class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): :param gamma: Learning rate of the task-weight update. Must be positive. :param rho: Non-negative :math:`\ell_2` regularisation coefficient. + .. note:: + The Euclidean projection onto the simplex used in the :math:`\lambda` update is adapted from + the `official implementation `_. + .. admonition:: Example (two batches per step) The following example reproduces basic MoDo using two independent mini-batches per step. @@ -159,11 +165,27 @@ def forward(self, matrix: Matrix, /) -> Tensor: lambd = cast(Tensor, self._lambda) grad = matrix @ lambd + self._rho * lambd - lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) + lambd = self._projection2simplex(lambd - self._gamma * grad) self._lambda = lambd return lambd + @staticmethod + def _projection2simplex(y: Tensor) -> Tensor: + """Euclidean projection of ``y`` onto the probability simplex.""" + + m = len(y) + sorted_y = torch.sort(y, descending=True)[0] + tmpsum = y.new_zeros(()) + tmax_f = (torch.sum(y) - 1.0) / m + for i in range(m - 1): + tmpsum = tmpsum + sorted_y[i] + tmax = (tmpsum - 1.0) / (i + 1.0) + if tmax > sorted_y[i + 1]: + tmax_f = tmax + break + return torch.max(y - tmax_f, y.new_zeros(m)) + def _ensure_state(self, matrix: Matrix) -> None: key = (matrix.shape[0], matrix.dtype, matrix.device) if self._state_key == key and self._lambda is not None: diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py index 654ac04f..f169ea86 100644 --- a/tests/unit/aggregation/test_modo.py +++ b/tests/unit/aggregation/test_modo.py @@ -1,11 +1,29 @@ import torch from pytest import raises +from torch import Tensor from torch.testing import assert_close from utils.tensors import randn_, tensor_ from torchjd.aggregation._modo import MoDoWeighting +def _project_to_simplex(y: Tensor) -> Tensor: + """Reference Euclidean projection onto the probability simplex, used to derive expected + values independently of the implementation.""" + + m = len(y) + sorted_y = torch.sort(y, descending=True)[0] + tmpsum = y.new_zeros(()) + tmax_f = (torch.sum(y) - 1.0) / m + for i in range(m - 1): + tmpsum = tmpsum + sorted_y[i] + tmax = (tmpsum - 1.0) / (i + 1.0) + if tmax > sorted_y[i + 1]: + tmax_f = tmax + break + return torch.max(y - tmax_f, y.new_zeros(m)) + + def test_representations() -> None: W = MoDoWeighting(gamma=0.1, rho=0.05) assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)" @@ -80,7 +98,7 @@ def test_update_recurrence() -> None: W = MoDoWeighting(gamma=gamma, rho=rho) lambda_0 = tensor_([1.0 / m] * m) grad = G @ lambda_0 + rho * lambda_0 - expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + expected = _project_to_simplex(lambda_0 - gamma * grad) assert_close(W(G), expected) @@ -102,10 +120,10 @@ def test_two_consecutive_steps() -> None: lambda_0 = tensor_([1.0 / m] * m) grad_1 = G1 @ lambda_0 + rho * lambda_0 - lambda_1 = torch.softmax(lambda_0 - gamma * grad_1, dim=-1) + lambda_1 = _project_to_simplex(lambda_0 - gamma * grad_1) grad_2 = G2 @ lambda_1 + rho * lambda_1 - lambda_2 = torch.softmax(lambda_1 - gamma * grad_2, dim=-1) + lambda_2 = _project_to_simplex(lambda_1 - gamma * grad_2) assert_close(W(G1), lambda_1) assert_close(W(G2), lambda_2) @@ -147,8 +165,23 @@ def test_non_symmetric_input() -> None: W = MoDoWeighting(gamma=gamma, rho=rho) lambda_0 = tensor_([1.0 / m] * m) grad = G @ lambda_0 + rho * lambda_0 - expected = torch.softmax(lambda_0 - gamma * grad, dim=-1) + expected = _project_to_simplex(lambda_0 - gamma * grad) assert_close(W(G), expected) assert W(G).shape == (m,) assert (W(G) >= 0).all() + + +def test_projection2simplex_known_values() -> None: + """The simplex projection matches hand-computed Euclidean projections.""" + + # Already-positive input: the deficit (1 - sum) is spread equally, no clamping. + assert_close( + MoDoWeighting._projection2simplex(tensor_([0.5, 0.1, 0.1])), + tensor_([0.6, 0.2, 0.2]), + ) + # Input with a negative entry: it gets clamped to zero. + assert_close( + MoDoWeighting._projection2simplex(tensor_([1.0, 0.0, -0.5])), + tensor_([1.0, 0.0, 0.0]), + )