Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a softmax-projected gradient step on a cross-batch matrix `G = J_1 @ J_2.T`, computed from two independent mini-batches using `autojac.jac`.
- Added `GeometricMean` (also known as GLS) studied in [MultiNet++: Multi-Stream Feature
Aggregation and Geometric Loss Strategy for Multi-Task
Learning](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Chennupati_MultiNet_Multi-Stream_Feature_Aggregation_and_Geometric_Loss_Strategy_for_Multi-Task_CVPRW_2019_paper.pdf),
Expand Down
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Abstract base classes
krum.rst
mean.rst
mgda.rst
modo.rst
nash_mtl.rst
pcgrad.rst
random.rst
Expand Down
7 changes: 7 additions & 0 deletions docs/source/docs/aggregation/modo.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

MoDo
====

.. autoclass:: torchjd.aggregation.MoDoWeighting
:members: __call__, reset
2 changes: 2 additions & 0 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._mixins import Stateful
from ._modo import MoDoWeighting
from ._nash_mtl import NashMTL
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
Expand Down Expand Up @@ -87,6 +88,7 @@
"MeanWeighting",
"MGDA",
"MGDAWeighting",
"MoDoWeighting",
"NashMTL",
"PCGrad",
"PCGradWeighting",
Expand Down
175 changes: 175 additions & 0 deletions src/torchjd/aggregation/_modo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from __future__ import annotations

from typing import cast

import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
from torchjd.linalg import Matrix

from ._weighting_bases import _MatrixWeighting


class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Three-Way
Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance
<https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf>`_ (JMLR 2024).

.. warning::
The input matrix must be :math:`G = J_1 J_2^\top`, computed from two **independent**
mini-batches via :func:`torchjd.autojac.jac`. Using a single-batch Gramian
(:math:`J_1 J_1^\top`) breaks the convergence guarantee. See the usage examples below.

:param gamma: Learning rate of the task-weight update. Must be positive.
:param rho: Non-negative :math:`\ell_2` regularisation coefficient.

.. admonition:: Example (two batches per step)

The following example reproduces basic MoDo using two independent mini-batches per step.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could add that this is MoDo as described in the paper, and it's the behavior of the official implementation when three_grads is False.

This matches MoDo as described in the paper, and the behavior of the official
implementation when ``three_grads`` is ``False``.

.. testcode::

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import MoDoWeighting
from torchjd.autojac import jac

# Generate data (8 batches of 16 examples of dim 5) for the sake of the example.
inputs = torch.randn(8, 16, 5)
targets = torch.randn(8, 16)

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
optimizer = SGD(model.parameters())
criterion = MSELoss(reduction="none")
weighting = MoDoWeighting(gamma=0.1, rho=0.0)
params = list(model.parameters())

# Consume two consecutive (independent) batches per step.
for i in range(len(inputs) // 2):
input_1, input_2 = inputs[2 * i], inputs[2 * i + 1]
target_1, target_2 = targets[2 * i], targets[2 * i + 1]

# retain_graph=True so both graphs survive for the backward step below.
losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
jacs_1 = jac(losses_1, params, retain_graph=True)
J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
jacs_2 = jac(losses_2, params, retain_graph=True)
J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

G = J_1 @ J_2.T
weights = weighting(G)

# Equation 2.9b: the parameter update uses the mean of both batches' losses.
losses = (losses_1 + losses_2) / 2.0
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()

.. admonition:: Example (three batches per step)

The following example reproduces basic MoDo using three independent mini-batches per step,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could add that this is the behavior of MoDo in LibMTL and in the official implementation when three_grads is True.

keeping the :math:`\lambda` update and the parameter update on separate draws. This matches
the behavior of LibMTL and of the official implementation when ``three_grads`` is ``True``.

.. testcode::

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import MoDoWeighting
from torchjd.autojac import jac

# Generate data (9 batches of 16 examples of dim 5) for the sake of the example.
inputs = torch.randn(9, 16, 5)
targets = torch.randn(9, 16)

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
optimizer = SGD(model.parameters())
criterion = MSELoss(reduction="none")
weighting = MoDoWeighting(gamma=0.1, rho=0.0)
params = list(model.parameters())

# Consume three consecutive (independent) batches per step.
for i in range(len(inputs) // 3):
input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2]
target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2]

losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
jacs_1 = jac(losses_1, params)
J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
jacs_2 = jac(losses_2, params)
J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

G = J_1 @ J_2.T
weights = weighting(G)

losses_3 = criterion(model(input_3).squeeze(dim=1), target_3)
losses_3.backward(weights)
optimizer.step()
optimizer.zero_grad()
"""

def __init__(self, gamma: float = 0.1, rho: float = 0.1) -> None:
    """
    Stores the hyper-parameters and starts with no task weights.

    :param gamma: Learning rate of the task-weight update. Must be positive.
    :param rho: Non-negative regularisation coefficient.
    :raises ValueError: If ``gamma`` or ``rho`` is rejected by its property setter.
    """

    # NOTE(review): presumably the base initialiser has to run before any attribute is set
    # on the instance -- confirm against the base classes.
    super().__init__()
    # Assigning through the properties (not the private fields) runs their validation.
    self.gamma = gamma
    self.rho = rho
    # Task weights carried from one call to the next; ``None`` until the first call.
    self._lambda: Tensor | None = None
    # (number of rows, dtype, device) of the matrix the stored weights were created for.
    self._state_key: tuple[int, torch.dtype, torch.device] | None = None

@property
def gamma(self) -> float:
    """Learning rate of the task-weight update."""

    return self._gamma

@gamma.setter
def gamma(self, value: float) -> None:
    # Written as ``not value > 0.0`` instead of ``value <= 0.0`` so that NaN is rejected too:
    # every ordered comparison with NaN is False, so the previous check silently accepted it
    # and every later weight update became NaN.
    if not value > 0.0:
        raise ValueError(f"Attribute `gamma` must be positive. Found gamma={value!r}.")
    self._gamma = value

@property
def rho(self) -> float:
    """Regularisation coefficient applied to the task weights in the update."""

    return self._rho

@rho.setter
def rho(self, value: float) -> None:
    # Written as ``not value >= 0.0`` instead of ``value < 0.0`` so that NaN is rejected too:
    # every ordered comparison with NaN is False, so the previous check silently accepted it.
    if not value >= 0.0:
        raise ValueError(f"Attribute `rho` must be non-negative. Found rho={value!r}.")
    self._rho = value

def reset(self) -> None:
    """Forgets the stored task weights and the key of the matrix they were created for."""

    self._state_key = None
    self._lambda = None

def forward(self, matrix: Matrix, /) -> Tensor:
    """
    Performs one update of the stored task weights and returns the updated weights.

    :param matrix: Matrix driving the update; its number of rows is the number of weights.
    :returns: The updated weights, which are also kept as the state for the next call.
    """

    self._ensure_state(matrix)
    previous = cast(Tensor, self._lambda)

    # Gradient step on the weights, with the ``rho``-scaled regularisation term included.
    stepped = previous - self._gamma * (matrix @ previous + self._rho * previous)

    # NOTE(review): open question from code review -- the official MoDo implementation and
    # LibMTL do not use a softmax here but a sort-based simplex projection
    # (``_projection2simplex``). If that projection is adopted, this file must state that
    # parts of it were adapted from the official implementation, link to it, and a notice
    # must be added to NOTICES.
    updated = torch.softmax(stepped, dim=-1)

    self._lambda = updated
    return updated

def _ensure_state(self, matrix: Matrix) -> None:
    """
    Makes sure stored weights exist and match ``matrix``.

    The weights are (re-)created as uniform, with one entry per row of ``matrix``, whenever
    none are stored or the stored ones were created for another row count, dtype or device.
    """

    m = matrix.shape[0]
    key = (m, matrix.dtype, matrix.device)
    up_to_date = self._lambda is not None and self._state_key == key
    if not up_to_date:
        self._lambda = matrix.new_full((m,), 1.0 / m)
        self._state_key = key

def __repr__(self) -> str:
    """Returns the class name followed by the current ``gamma`` and ``rho`` values."""

    name = self.__class__.__name__
    return f"{name}(gamma={self.gamma!r}, rho={self.rho!r})"
154 changes: 154 additions & 0 deletions tests/unit/aggregation/test_modo.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To better reflect the new usage, I think we should use G = J1 @ J2.T instead of G = J @ J.T in test_reset_restores_first_step_behavior, test_output_lies_on_simplex, test_update_recurrence and test_changing_m_auto_resets.

Similarly, we should use G1 = J1 @ J2.T and G2 = J3 @ J4.T in test_two_consecutive_steps.

Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import torch
from pytest import raises
from torch.testing import assert_close
from utils.tensors import randn_, tensor_

from torchjd.aggregation._modo import MoDoWeighting


def test_representations() -> None:
    """``repr`` spells out the class name and both hyper-parameters."""

    weighting = MoDoWeighting(gamma=0.1, rho=0.05)
    expected = "MoDoWeighting(gamma=0.1, rho=0.05)"
    assert repr(weighting) == expected


def test_reset_restores_first_step_behavior() -> None:
    """After ``reset``, the weighting reproduces the output of its very first call."""

    jac_a = randn_((3, 8))
    jac_b = randn_((3, 8))
    cross = jac_a @ jac_b.T
    weighting = MoDoWeighting(gamma=0.1)

    initial_output = weighting(cross)
    weighting(cross)  # Extra call so the stored weights differ from their initial value.
    weighting.reset()

    assert_close(initial_output, weighting(cross))


def test_gamma_setter_accepts_valid() -> None:
    """Positive values of ``gamma`` are stored and read back unchanged."""

    weighting = MoDoWeighting()
    for candidate in (0.01, 0.1, 1.0):
        weighting.gamma = candidate
        assert weighting.gamma == candidate


def test_gamma_setter_rejects_non_positive() -> None:
    """Zero and negative values of ``gamma`` raise a ``ValueError`` naming the attribute."""

    weighting = MoDoWeighting()
    for invalid in (0.0, -0.1):
        with raises(ValueError, match="gamma"):
            weighting.gamma = invalid


def test_rho_setter_accepts_valid() -> None:
    """Zero and positive values of ``rho`` are stored and read back unchanged."""

    weighting = MoDoWeighting()
    for candidate in (0.0, 0.1):
        weighting.rho = candidate
        assert weighting.rho == candidate


def test_rho_setter_rejects_negative() -> None:
    """A negative ``rho`` raises a ``ValueError`` naming the attribute."""

    weighting = MoDoWeighting()
    negative_rho = -0.1
    with raises(ValueError, match="rho"):
        weighting.rho = negative_rho


def test_output_lies_on_simplex() -> None:
    """The returned weights have one entry per row, are non-negative and sum to one."""

    jac_a = randn_((4, 10))
    jac_b = randn_((4, 10))
    weighting = MoDoWeighting(gamma=0.1, rho=0.05)

    result = weighting(jac_a @ jac_b.T)

    assert result.shape == (4,)
    assert (result >= 0).all()
    assert_close(result.sum(), tensor_(1.0))


def test_update_recurrence() -> None:
    """The first output equals one hand-computed update step started from uniform weights."""

    gamma, rho = 0.1, 0.05
    jac_a = randn_((3, 8))
    jac_b = randn_((3, 8))
    cross = jac_a @ jac_b.T
    n_tasks = jac_a.shape[0]
    weighting = MoDoWeighting(gamma=gamma, rho=rho)

    # Reference computation: uniform start, gradient step, softmax.
    uniform = tensor_([1.0 / n_tasks] * n_tasks)
    direction = cross @ uniform + rho * uniform
    reference = torch.softmax(uniform - gamma * direction, dim=-1)

    assert_close(weighting(cross), reference)


def test_two_consecutive_steps() -> None:
    """The second output continues from the first one, each step using its own matrix."""

    gamma, rho = 0.1, 0.0
    jac_a = randn_((3, 8))
    jac_b = randn_((3, 8))
    jac_c = randn_((3, 8))
    jac_d = randn_((3, 8))
    first_cross = jac_a @ jac_b.T
    second_cross = jac_c @ jac_d.T
    n_tasks = jac_a.shape[0]

    weighting = MoDoWeighting(gamma=gamma, rho=rho)

    # Reference trajectory: two hand-computed updates, the second starting from the first.
    reference = [tensor_([1.0 / n_tasks] * n_tasks)]
    for cross in (first_cross, second_cross):
        current = reference[-1]
        direction = cross @ current + rho * current
        reference.append(torch.softmax(current - gamma * direction, dim=-1))

    assert_close(weighting(first_cross), reference[1])
    assert_close(weighting(second_cross), reference[2])


def test_changing_m_auto_resets() -> None:
    """A matrix with a different number of rows makes the weighting behave like a new one."""

    used = MoDoWeighting(gamma=0.1)
    three_a = randn_((3, 8))
    three_b = randn_((3, 8))
    used(three_a @ three_b.T)  # Leaves state for three objectives behind.

    # With two objectives, the used instance must match one that was never called.
    untouched = MoDoWeighting(gamma=0.1)
    two_a = randn_((2, 8))
    two_b = randn_((2, 8))
    cross = two_a @ two_b.T
    assert_close(used(cross), untouched(cross))


def test_non_differentiable() -> None:
    """The weights do not require grad even when the input matrix does."""

    jac_a = randn_((3, 8))
    jac_b = randn_((3, 8))
    cross = (jac_a @ jac_b.T).requires_grad_(True)

    result = MoDoWeighting()(cross)

    assert not result.requires_grad


def test_non_symmetric_input() -> None:
    """MoDoWeighting must accept and correctly process a non-symmetric cross-batch matrix."""

    gamma = 0.1
    rho = 0.05
    J1 = randn_((3, 8))
    J2 = randn_((3, 8))
    G = J1 @ J2.T  # not symmetric, not PSD in general
    m = J1.shape[0]

    W = MoDoWeighting(gamma=gamma, rho=rho)
    lambda_0 = tensor_([1.0 / m] * m)
    grad = G @ lambda_0 + rho * lambda_0
    expected = torch.softmax(lambda_0 - gamma * grad, dim=-1)

    # The weighting is stateful: every call performs one more update. Call it exactly once so
    # that the shape and sign checks apply to the same first-step output that is compared to
    # ``expected``, instead of to the outputs of a second and third update.
    weights = W(G)
    assert_close(weights, expected)
    assert weights.shape == (m,)
    assert (weights >= 0).all()
Loading