diff --git a/CHANGELOG.md b/CHANGELOG.md index ab5a536d..7bd9921e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a simplex-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`. - Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf), diff --git a/NOTICES b/NOTICES index 18a1d601..07c3e851 100644 --- a/NOTICES +++ b/NOTICES @@ -112,3 +112,31 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +------------------------------------------------------------------------------- + +Project: MoDo +Source: https://github.com/heshandevaka/Trade-Off-MOL/blob/main/LibMTL/LibMTL/weighting/MoDo.py +Used in: src/torchjd/aggregation/_modo.py + +MIT License + +Copyright (c) 2023 Heshan Fernando + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 13e405cb..66d74570 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -41,6 +41,7 @@ Abstract base classes krum.rst mean.rst mgda.rst + modo.rst nash_mtl.rst pcgrad.rst random.rst diff --git a/docs/source/docs/aggregation/modo.rst b/docs/source/docs/aggregation/modo.rst new file mode 100644 index 00000000..98b8d515 --- /dev/null +++ b/docs/source/docs/aggregation/modo.rst @@ -0,0 +1,7 @@ +:hide-toc: + +MoDo +==== + +.. 
autoclass:: torchjd.aggregation.MoDoWeighting + :members: __call__, reset diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 1814d320..92bbadec 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -53,6 +53,7 @@ from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting from ._mixins import Stateful +from ._modo import MoDoWeighting from ._nash_mtl import NashMTL from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting @@ -87,6 +88,7 @@ "MeanWeighting", "MGDA", "MGDAWeighting", + "MoDoWeighting", "NashMTL", "PCGrad", "PCGradWeighting", diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py new file mode 100644 index 00000000..7c1a2e83 --- /dev/null +++ b/src/torchjd/aggregation/_modo.py @@ -0,0 +1,197 @@ +# Partly adapted from https://github.com/heshandevaka/Trade-Off-MOL — MIT License, Copyright (c) 2023 Heshan Fernando. +# See NOTICES for the full license text. +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd.aggregation._mixins import Stateful, _NonDifferentiable +from torchjd.linalg import Matrix + +from ._weighting_bases import _MatrixWeighting + + +class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Three-Way + Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance + `_ (JMLR 2024). + + .. warning:: + The input matrix must be :math:`G = J_1 J_2^\top`, computed from two **independent** + mini-batches via :func:`torchjd.autojac.jac`. Using a single-batch Gramian + (:math:`J_1 J_1^\top`) breaks the convergence guarantee. See the usage examples below. + + :param gamma: Learning rate of the task-weight update. Must be positive. 
+ :param rho: Non-negative :math:`\ell_2` regularisation coefficient. + + .. note:: + The Euclidean projection onto the simplex used in the :math:`\lambda` update is adapted from + the `official implementation `_. + + .. admonition:: Example (two batches per step) + + The following example reproduces basic MoDo using two independent mini-batches per step. + This matches MoDo as described in the paper, and the behavior of the official + implementation when ``three_grads`` is ``False``. + + .. testcode:: + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autojac import jac + + # Generate data (8 batches of 16 examples of dim 5) for the sake of the example. + inputs = torch.randn(8, 16, 5) + targets = torch.randn(8, 16) + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + params = list(model.parameters()) + + # Consume two consecutive (independent) batches per step. + for i in range(len(inputs) // 2): + input_1, input_2 = inputs[2 * i], inputs[2 * i + 1] + target_1, target_2 = targets[2 * i], targets[2 * i + 1] + + # retain_graph=True so both graphs survive for the backward step below. + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + jacs_1 = jac(losses_1, params, retain_graph=True) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) + + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params, retain_graph=True) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + G = J_1 @ J_2.T + weights = weighting(G) + + # Equation 2.9b: the parameter update uses the mean of both batches' losses. + losses = (losses_1 + losses_2) / 2.0 + losses.backward(weights) + optimizer.step() + optimizer.zero_grad() + + .. 
admonition:: Example (three batches per step) + + The following example reproduces basic MoDo using three independent mini-batches per step, + keeping the :math:`\lambda` update and the parameter update on separate draws. This matches + the behavior of LibMTL and of the official implementation when ``three_grads`` is ``True``. + + .. testcode:: + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import MoDoWeighting + from torchjd.autojac import jac + + # Generate data (9 batches of 16 examples of dim 5) for the sake of the example. + inputs = torch.randn(9, 16, 5) + targets = torch.randn(9, 16) + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + criterion = MSELoss(reduction="none") + weighting = MoDoWeighting(gamma=0.1, rho=0.0) + params = list(model.parameters()) + + # Consume three consecutive (independent) batches per step. + for i in range(len(inputs) // 3): + input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2] + target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2] + + losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) + jacs_1 = jac(losses_1, params) + J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) + + losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) + jacs_2 = jac(losses_2, params) + J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) + + G = J_1 @ J_2.T + weights = weighting(G) + + losses_3 = criterion(model(input_3).squeeze(dim=1), target_3) + losses_3.backward(weights) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__(self, gamma: float = 0.1, rho: float = 0.1) -> None: + super().__init__() + self.gamma = gamma + self.rho = rho + self._lambda: Tensor | None = None + self._state_key: tuple[int, torch.dtype, torch.device] | None = None + + @property + def gamma(self) -> float: + return self._gamma + + @gamma.setter + def 
gamma(self, value: float) -> None: + if value <= 0.0: + raise ValueError(f"Attribute `gamma` must be positive. Found gamma={value!r}.") + self._gamma = value + + @property + def rho(self) -> float: + return self._rho + + @rho.setter + def rho(self, value: float) -> None: + if value < 0.0: + raise ValueError(f"Attribute `rho` must be non-negative. Found rho={value!r}.") + self._rho = value + + def reset(self) -> None: + """Clears the stored task weights so the next forward starts from uniform.""" + + self._lambda = None + self._state_key = None + + def forward(self, matrix: Matrix, /) -> Tensor: + self._ensure_state(matrix) + lambd = cast(Tensor, self._lambda) + + grad = matrix @ lambd + self._rho * lambd + lambd = self._projection2simplex(lambd - self._gamma * grad) + + self._lambda = lambd + return lambd + + @staticmethod + def _projection2simplex(y: Tensor) -> Tensor: + """Euclidean projection of ``y`` onto the probability simplex.""" + + m = len(y) + sorted_y = torch.sort(y, descending=True)[0] + tmpsum = y.new_zeros(()) + tmax_f = (torch.sum(y) - 1.0) / m + for i in range(m - 1): + tmpsum = tmpsum + sorted_y[i] + tmax = (tmpsum - 1.0) / (i + 1.0) + if tmax > sorted_y[i + 1]: + tmax_f = tmax + break + return torch.max(y - tmax_f, y.new_zeros(m)) + + def _ensure_state(self, matrix: Matrix) -> None: + key = (matrix.shape[0], matrix.dtype, matrix.device) + if self._state_key == key and self._lambda is not None: + return + self._lambda = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0]) + self._state_key = key + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(gamma={self.gamma!r}, rho={self.rho!r})" diff --git a/tests/unit/aggregation/test_modo.py b/tests/unit/aggregation/test_modo.py new file mode 100644 index 00000000..f169ea86 --- /dev/null +++ b/tests/unit/aggregation/test_modo.py @@ -0,0 +1,187 @@ +import torch +from pytest import raises +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import 
import torch
from pytest import raises
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import randn_, tensor_

from torchjd.aggregation._modo import MoDoWeighting


def _project_to_simplex(y: Tensor) -> Tensor:
    """
    Reference Euclidean projection of ``y`` onto the probability simplex.

    It deliberately uses the closed-form sorted-cumulative-sum threshold instead of the early-exit
    loop of ``MoDoWeighting._projection2simplex``, so that expected values are derived
    independently of the implementation under test.
    """

    sorted_y = torch.sort(y, descending=True)[0]
    cumulative = torch.cumsum(sorted_y, dim=0)
    ranks = torch.arange(1, len(y) + 1, dtype=y.dtype, device=y.device)
    # Number of entries that stay strictly positive after the projection. The condition always
    # holds for the largest entry, so this is at least 1 for a non-empty input.
    support = int((sorted_y * ranks > cumulative - 1.0).sum())
    threshold = (cumulative[support - 1] - 1.0) / support
    return torch.clamp(y - threshold, min=0.0)


def test_representations() -> None:
    W = MoDoWeighting(gamma=0.1, rho=0.05)
    assert repr(W) == "MoDoWeighting(gamma=0.1, rho=0.05)"


def test_reset_restores_first_step_behavior() -> None:
    J1 = randn_((3, 8))
    J2 = randn_((3, 8))
    G = J1 @ J2.T
    W = MoDoWeighting(gamma=0.1)
    first = W(G)
    W(G)
    W.reset()
    assert_close(first, W(G))


def test_gamma_setter_accepts_valid() -> None:
    W = MoDoWeighting()
    W.gamma = 0.01
    assert W.gamma == 0.01
    W.gamma = 0.1
    assert W.gamma == 0.1
    W.gamma = 1.0
    assert W.gamma == 1.0


def test_gamma_setter_rejects_non_positive() -> None:
    W = MoDoWeighting()
    with raises(ValueError, match="gamma"):
        W.gamma = 0.0
    with raises(ValueError, match="gamma"):
        W.gamma = -0.1


def test_rho_setter_accepts_valid() -> None:
    W = MoDoWeighting()
    W.rho = 0.0
    assert W.rho == 0.0
    W.rho = 0.1
    assert W.rho == 0.1


def test_rho_setter_rejects_negative() -> None:
    W = MoDoWeighting()
    with raises(ValueError, match="rho"):
        W.rho = -0.1


def test_output_lies_on_simplex() -> None:
    """The simplex projection ensures the weights sum to 1 and are non-negative."""

    J1 = randn_((4, 10))
    J2 = randn_((4, 10))
    G = J1 @ J2.T
    W = MoDoWeighting(gamma=0.1, rho=0.05)
    weights = W(G)
    assert weights.shape == (4,)
    assert (weights >= 0).all()
    assert_close(weights.sum(), tensor_(1.0))


def test_update_recurrence() -> None:
    """Verify one step of the simplex-projected gradient update by hand."""

    gamma = 0.1
    rho = 0.05
    J1 = randn_((3, 8))
    J2 = randn_((3, 8))
    G = J1 @ J2.T
    m = J1.shape[0]

    W = MoDoWeighting(gamma=gamma, rho=rho)
    lambda_0 = tensor_([1.0 / m] * m)
    grad = G @ lambda_0 + rho * lambda_0
    expected = _project_to_simplex(lambda_0 - gamma * grad)

    assert_close(W(G), expected)


def test_two_consecutive_steps() -> None:
    """Verify two consecutive steps of the simplex-projected gradient update."""

    gamma = 0.1
    rho = 0.0
    J1 = randn_((3, 8))
    J2 = randn_((3, 8))
    J3 = randn_((3, 8))
    J4 = randn_((3, 8))
    G1 = J1 @ J2.T
    G2 = J3 @ J4.T
    m = J1.shape[0]

    W = MoDoWeighting(gamma=gamma, rho=rho)

    lambda_0 = tensor_([1.0 / m] * m)
    grad_1 = G1 @ lambda_0 + rho * lambda_0
    lambda_1 = _project_to_simplex(lambda_0 - gamma * grad_1)

    grad_2 = G2 @ lambda_1 + rho * lambda_1
    lambda_2 = _project_to_simplex(lambda_1 - gamma * grad_2)

    assert_close(W(G1), lambda_1)
    assert_close(W(G2), lambda_2)


def test_changing_m_auto_resets() -> None:
    """When the number of objectives changes, the state is re-initialised to uniform."""

    W = MoDoWeighting(gamma=0.1)
    W(randn_((3, 8)) @ randn_((3, 8)).T)
    # After a state-resetting call with m=2, the first output should equal the uniform step's output.
    fresh = MoDoWeighting(gamma=0.1)
    J1 = randn_((2, 8))
    J2 = randn_((2, 8))
    G = J1 @ J2.T
    assert_close(W(G), fresh(G))


def test_non_differentiable() -> None:
    """The _NonDifferentiable mixin must prevent autograd graph construction."""

    G = randn_((3, 8)) @ randn_((3, 8)).T
    G.requires_grad_(True)
    W = MoDoWeighting()
    weights = W(G)
    assert not weights.requires_grad


def test_non_symmetric_input() -> None:
    """MoDoWeighting must accept and correctly process a non-symmetric cross-batch matrix."""

    gamma = 0.1
    rho = 0.05
    J1 = randn_((3, 8))
    J2 = randn_((3, 8))
    G = J1 @ J2.T  # not symmetric, not PSD in general
    m = J1.shape[0]

    W = MoDoWeighting(gamma=gamma, rho=rho)
    lambda_0 = tensor_([1.0 / m] * m)
    grad = G @ lambda_0 + rho * lambda_0
    expected = _project_to_simplex(lambda_0 - gamma * grad)

    # The weighting is stateful: call it once and check every property on that same output,
    # rather than on the outputs of later steps.
    weights = W(G)
    assert_close(weights, expected)
    assert weights.shape == (m,)
    assert (weights >= 0).all()


def test_projection2simplex_known_values() -> None:
    """The simplex projection matches hand-computed Euclidean projections."""

    # Already-positive input: the deficit (1 - sum) is spread equally, no clamping.
    assert_close(
        MoDoWeighting._projection2simplex(tensor_([0.5, 0.1, 0.1])),
        tensor_([0.6, 0.2, 0.2]),
    )
    # Input with a negative entry: it gets clamped to zero.
    assert_close(
        MoDoWeighting._projection2simplex(tensor_([1.0, 0.0, -0.5])),
        tensor_([1.0, 0.0, 0.0]),
    )