diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst
index c15d5980f..dd427bd0d 100644
--- a/docs/source/docs/aggregation/index.rst
+++ b/docs/source/docs/aggregation/index.rst
@@ -42,5 +42,6 @@ Abstract base classes
     nash_mtl.rst
     pcgrad.rst
     random.rst
+    stch.rst
     sum.rst
     trimmed_mean.rst
diff --git a/docs/source/docs/aggregation/stch.rst b/docs/source/docs/aggregation/stch.rst
new file mode 100644
index 000000000..57154a859
--- /dev/null
+++ b/docs/source/docs/aggregation/stch.rst
@@ -0,0 +1,14 @@
+:hide-toc:
+
+STCH
+====
+
+.. autoclass:: torchjd.aggregation.STCH
+    :members:
+    :undoc-members:
+    :exclude-members: forward
+
+.. autoclass:: torchjd.aggregation.STCHWeighting
+    :members:
+    :undoc-members:
+    :exclude-members: forward
diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py
index 9eed9bf7e..b8a30fe20 100644
--- a/src/torchjd/aggregation/__init__.py
+++ b/src/torchjd/aggregation/__init__.py
@@ -72,6 +72,7 @@
 from ._mgda import MGDA, MGDAWeighting
 from ._pcgrad import PCGrad, PCGradWeighting
 from ._random import Random, RandomWeighting
+from ._stch import STCH, STCHWeighting
 from ._sum import Sum, SumWeighting
 from ._trimmed_mean import TrimmedMean
 from ._upgrad import UPGrad, UPGradWeighting
@@ -104,6 +105,8 @@
     "PCGradWeighting",
     "Random",
     "RandomWeighting",
+    "STCH",
+    "STCHWeighting",
     "Sum",
     "SumWeighting",
     "TrimmedMean",
diff --git a/src/torchjd/aggregation/_stch.py b/src/torchjd/aggregation/_stch.py
new file mode 100644
index 000000000..5d9104d3a
--- /dev/null
+++ b/src/torchjd/aggregation/_stch.py
@@ -0,0 +1,225 @@
+# The code of this file was adapted from
+# https://github.com/Xi-L/STCH/blob/main/STCH_MTL/LibMTL/weighting/STCH.py.
+# It is therefore also subject to the following license.
+#
+# MIT License
+#
+# Copyright (c) 2024 Xi Lin
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+from torch import Tensor
+
+from torchjd._linalg import PSDMatrix
+
+from ._aggregator_bases import GramianWeightedAggregator
+from ._weighting_bases import Weighting
+
+
+class STCH(GramianWeightedAggregator):
+    r"""
+    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the Smooth Tchebycheff
+    scalarization as proposed in `Smooth Tchebycheff Scalarization for Multi-Objective Optimization
+    <https://arxiv.org/abs/2402.19078>`_.
+
+    This aggregator uses the log-sum-exp (smooth maximum) function to compute weights that focus
+    more on poorly performing tasks (tasks with larger gradient norms). The ``mu`` parameter
+    controls the smoothness: as ``mu`` approaches 0, the weights converge to a hard maximum
+    (focusing entirely on the worst task); as ``mu`` increases, the weights approach uniform
+    averaging.
+
+    :param mu: The smoothness parameter for the log-sum-exp. Smaller values give more weight to the
+        worst-performing task. Must be positive.
+    :param warmup_steps: Optional number of steps for the warmup phase. During warmup, gradient
+        norms are accumulated to compute a nadir vector for normalization. If ``None`` (default),
+        no warmup is performed and raw gradient norms are used directly.
+    :param eps: A small value to avoid numerical issues in log computations.
+
+    .. warning::
+        If ``warmup_steps`` is set, this aggregator becomes stateful. Its output will depend not
+        only on the input matrix, but also on its internal state (previously seen matrices). It
+        should be reset between experiments using the :meth:`reset` method.
+
+    .. note::
+        The original STCH algorithm operates on loss values. This implementation adapts it for
+        gradient-based aggregation using gradient norms (derived from the Gramian diagonal) as
+        proxies for task performance.
+
+    Example
+    -------
+
+    >>> from torch import tensor
+    >>> from torchjd.aggregation import STCH
+    >>>
+    >>> A = STCH(mu=1.0)
+    >>> J = tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
+    >>> A(J)
+    tensor([1.8188, 1.0000, 1.0000])
+
+    .. note::
+        This implementation was adapted from the `official implementation
+        <https://github.com/Xi-L/STCH>`_.
+    """
+
+    def __init__(
+        self,
+        mu: float = 1.0,
+        warmup_steps: int | None = None,
+        eps: float = 1e-20,
+    ):
+        if mu <= 0.0:
+            raise ValueError(f"Parameter `mu` should be a positive float. Found `mu = {mu}`.")
+
+        if warmup_steps is not None and warmup_steps < 1:
+            raise ValueError(
+                f"Parameter `warmup_steps` should be a positive integer or None. "
+                f"Found `warmup_steps = {warmup_steps}`."
+            )
+
+        stch_weighting = STCHWeighting(mu=mu, warmup_steps=warmup_steps, eps=eps)
+        super().__init__(stch_weighting)
+
+        self._mu = mu
+        self._warmup_steps = warmup_steps
+        self._eps = eps
+        self._stch_weighting = stch_weighting
+
+    def reset(self) -> None:
+        """Resets the internal state of the algorithm (step counter and accumulated nadir)."""
+        self._stch_weighting.reset()
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(mu={self._mu}, warmup_steps={self._warmup_steps}, "
+            f"eps={self._eps})"
+        )
+
+    def __str__(self) -> str:
+        mu_str = f"{self._mu:g}"  # trims trailing zeros without mangling values such as 10.0
+        return f"STCH(mu={mu_str})"
+
+
+class STCHWeighting(Weighting[PSDMatrix]):
+    r"""
+    :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
+    :class:`~torchjd.aggregation.STCH`.
+
+    The weights are computed using the Smooth Tchebycheff scalarization formula:
+
+    .. math::
+
+        w_i = \frac{\exp\left(\frac{\log(g_i / z_i) - \max_j \log(g_j / z_j)}{\mu}\right)}
+        {\sum_k \exp\left(\frac{\log(g_k / z_k) - \max_j \log(g_j / z_j)}{\mu}\right)}
+
+    where :math:`g_i` is the gradient norm for task :math:`i` (computed as :math:`\sqrt{G_{ii}}`
+    from the Gramian), :math:`z_i` is the nadir value for task :math:`i`, and :math:`\mu` is the
+    smoothness parameter.
+
+    :param mu: The smoothness parameter for the log-sum-exp. Must be positive.
+    :param warmup_steps: Optional number of steps for the warmup phase. During warmup, gradient
+        norms are accumulated to compute a nadir vector. If ``None``, no warmup is performed.
+    :param eps: A small value to avoid numerical issues in log computations.
+
+    .. warning::
+        If ``warmup_steps`` is set, this weighting becomes stateful. During warmup, it returns
+        uniform weights while accumulating gradient norms. After warmup, the accumulated average
+        is used as the nadir vector for normalization.
+    """
+
+    def __init__(
+        self,
+        mu: float = 1.0,
+        warmup_steps: int | None = None,
+        eps: float = 1e-20,
+    ):
+        super().__init__()
+
+        if mu <= 0.0:
+            raise ValueError(f"Parameter `mu` should be a positive float. Found `mu = {mu}`.")
+
+        if warmup_steps is not None and warmup_steps < 1:
+            raise ValueError(
+                f"Parameter `warmup_steps` should be a positive integer or None. "
+                f"Found `warmup_steps = {warmup_steps}`."
+            )
+
+        self.mu = mu
+        self.warmup_steps = warmup_steps
+        self.eps = eps
+
+        # Internal state for warmup
+        self.step = 0
+        self.nadir_accumulator: Tensor | None = None
+        self.nadir_vector: Tensor | None = None
+
+    def reset(self) -> None:
+        """Resets the internal state of the algorithm."""
+        self.step = 0
+        self.nadir_accumulator = None
+        self.nadir_vector = None
+
+    def forward(self, gramian: PSDMatrix) -> Tensor:
+        device = gramian.device
+        dtype = gramian.dtype
+        m = gramian.shape[0]
+
+        # Compute gradient norms from Gramian diagonal (sqrt of diagonal)
+        grad_norms = torch.sqrt(torch.diag(gramian).clamp(min=self.eps))
+
+        # Handle warmup phase if warmup_steps is set
+        if self.warmup_steps is not None:
+            if self.step < self.warmup_steps:
+                # During warmup: accumulate gradient norms and return uniform weights
+                if self.nadir_accumulator is None:
+                    self.nadir_accumulator = grad_norms.detach().clone()
+                else:
+                    self.nadir_accumulator = (
+                        self.nadir_accumulator.to(device=device, dtype=dtype) + grad_norms.detach()
+                    )
+
+                self.step += 1
+
+                # Return uniform weights during warmup
+                return torch.full(size=[m], fill_value=1.0 / m, device=device, dtype=dtype)
+
+            elif self.nadir_vector is None:
+                # First step after warmup: compute nadir vector from accumulated values
+                self.nadir_vector = self.nadir_accumulator / self.warmup_steps  # type: ignore
+                self.step += 1
+            else:
+                self.step += 1
+
+        # Normalize by nadir vector if available (after warmup)
+        if self.nadir_vector is not None:
+            nadir = self.nadir_vector.to(device=device, dtype=dtype)
+            normalized = grad_norms / nadir.clamp(min=self.eps)
+        else:
+            normalized = grad_norms
+
+        # Apply log and compute smooth max weights using log-sum-exp trick for numerical stability
+        log_normalized = torch.log(normalized + self.eps)
+        max_log = torch.max(log_normalized)
+        reg_log = (log_normalized - max_log) / self.mu
+
+        # Softmax weights
+        exp_reg = torch.exp(reg_log)
+        weights = exp_reg / exp_reg.sum()
+
+        return weights
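Reviewer note, not part of the patch: a quick hand-check of the weighting formula above. It assumes ``STCHWeighting`` can be called directly on a Gramian tensor, exactly as the tests below do; with ``mu=1`` and no warmup, the weights reduce to the gradient norms normalized to sum to one::

    import torch
    from torchjd.aggregation import STCHWeighting

    weighting = STCHWeighting(mu=1.0)  # no warmup, so raw gradient norms are used
    gramian = torch.tensor([[4.0, 0.0], [0.0, 16.0]])  # norms: sqrt of the diagonal = [2, 4]
    weights = weighting(gramian)

    # log-norms: [log 2, log 4]; subtracting the max and exponentiating gives [0.5, 1.0],
    # which normalizes to [1/3, 2/3], favouring the task with the larger gradient norm.
    assert torch.allclose(weights, torch.tensor([1.0 / 3.0, 2.0 / 3.0]), atol=1e-4)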
diff --git a/tests/unit/aggregation/test_stch.py b/tests/unit/aggregation/test_stch.py
new file mode 100644
index 000000000..9dfa33d36
--- /dev/null
+++ b/tests/unit/aggregation/test_stch.py
@@ -0,0 +1,246 @@
+import torch
+from pytest import mark, raises
+from torch import Tensor
+
+from torchjd.aggregation import STCH, STCHWeighting
+
+from ._asserts import assert_expected_structure, assert_permutation_invariant
+from ._inputs import scaled_matrices, typical_matrices
+
+aggregators = [
+    STCH(),
+    STCH(mu=0.1),
+    STCH(mu=0.5),
+    STCH(mu=2.0),
+    STCH(mu=10.0),
+]
+scaled_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in scaled_matrices]
+typical_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in typical_matrices]
+
+
+@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
+def test_expected_structure(aggregator: STCH, matrix: Tensor):
+    assert_expected_structure(aggregator, matrix)
+
+
+@mark.parametrize(["aggregator", "matrix"], typical_pairs)
+def test_permutation_invariant(aggregator: STCH, matrix: Tensor):
+    assert_permutation_invariant(aggregator, matrix)
+
+
+def test_representations():
+    A = STCH(mu=1.0)
+    assert repr(A) == "STCH(mu=1.0, warmup_steps=None, eps=1e-20)"
+    assert str(A) == "STCH(mu=1)"
+
+    A = STCH(mu=0.5)
+    assert repr(A) == "STCH(mu=0.5, warmup_steps=None, eps=1e-20)"
+    assert str(A) == "STCH(mu=0.5)"
+
+    A = STCH(mu=1.0, warmup_steps=100)
+    assert repr(A) == "STCH(mu=1.0, warmup_steps=100, eps=1e-20)"
+    assert str(A) == "STCH(mu=1)"
+
+
+def test_invalid_mu():
+    with raises(ValueError, match=r"Parameter `mu` should be a positive float"):
+        STCH(mu=0.0)
+
+    with raises(ValueError, match=r"Parameter `mu` should be a positive float"):
+        STCH(mu=-1.0)
+
+
+def test_invalid_warmup_steps():
+    with raises(ValueError, match=r"Parameter `warmup_steps` should be a positive integer or None"):
+        STCH(warmup_steps=0)
+
+    with raises(ValueError, match=r"Parameter `warmup_steps` should be a positive integer or None"):
+        STCH(warmup_steps=-1)
+
+
+def test_weights_sum_to_one():
+    """Test that the weights computed by STCHWeighting sum to 1."""
+    weighting = STCHWeighting(mu=1.0)
+    gramian = torch.tensor([[4.0, 2.0], [2.0, 9.0]])  # Gradient norms are 2 and 3
+    weights = weighting(gramian)
+    assert torch.isclose(weights.sum(), torch.tensor(1.0), atol=1e-6)
+
+
+def test_small_mu_focuses_on_max():
+    """Test that small mu focuses weights on the task with largest gradient norm."""
+    gramian = torch.tensor([[1.0, 0.0], [0.0, 100.0]])  # Norms are 1 and 10
+
+    weighting_small_mu = STCHWeighting(mu=0.01)
+    weights = weighting_small_mu(gramian)
+
+    # With very small mu, the weight should be almost entirely on the second task
+    assert weights[1] > 0.99
+
+
+def test_large_mu_approaches_uniform():
+    """Test that large mu approaches uniform weighting."""
+    gramian = torch.tensor([[1.0, 0.0], [0.0, 100.0]])  # Very different norms
+
+    weighting_large_mu = STCHWeighting(mu=100.0)
+    weights = weighting_large_mu(gramian)
+
+    # With very large mu, weights should approach uniform [0.5, 0.5]
+    assert torch.isclose(weights[0], torch.tensor(0.5), atol=0.1)
+    assert torch.isclose(weights[1], torch.tensor(0.5), atol=0.1)
+
+
+def test_weighting_invalid_mu():
+    with raises(ValueError, match=r"Parameter `mu` should be a positive float"):
+        STCHWeighting(mu=0.0)
+
+    with raises(ValueError, match=r"Parameter `mu` should be a positive float"):
+        STCHWeighting(mu=-1.0)
+
+
+def test_weighting_invalid_warmup_steps():
+    with raises(ValueError, match=r"Parameter `warmup_steps` should be a positive integer or None"):
+        STCHWeighting(warmup_steps=0)
+
+    with raises(ValueError, match=r"Parameter `warmup_steps` should be a positive integer or None"):
+        STCHWeighting(warmup_steps=-1)
+
+
+# Tests for warmup functionality
+
+
+def test_warmup_returns_uniform_during_warmup():
+    """Test that during warmup, uniform weights are returned."""
+    weighting = STCHWeighting(mu=1.0, warmup_steps=3)
+    gramian = torch.tensor([[1.0, 0.0], [0.0, 100.0]])  # Very different norms
+
+    # During warmup, weights should be uniform regardless of gradient norms
+    for _ in range(3):
+        weights = weighting(gramian)
+        assert torch.isclose(weights[0], torch.tensor(0.5), atol=1e-6)
+        assert torch.isclose(weights[1], torch.tensor(0.5), atol=1e-6)
+
+
+def test_warmup_uses_nadir_after_warmup():
+    """Test that after warmup, the nadir vector is used for normalization."""
normalization.""" + weighting = STCHWeighting(mu=1.0, warmup_steps=2) + + # During warmup: accumulate gradient norms + gramian1 = torch.tensor([[4.0, 0.0], [0.0, 16.0]]) # Norms: [2, 4] + gramian2 = torch.tensor([[4.0, 0.0], [0.0, 16.0]]) # Norms: [2, 4] + + weights1 = weighting(gramian1) # Step 1: warmup + weights2 = weighting(gramian2) # Step 2: warmup + + # During warmup, should return uniform weights + assert torch.isclose(weights1[0], torch.tensor(0.5), atol=1e-6) + assert torch.isclose(weights2[0], torch.tensor(0.5), atol=1e-6) + + # After warmup, nadir should be [2, 4] (average of accumulated norms) + # Now with a gramian that has different norms + gramian3 = torch.tensor([[4.0, 0.0], [0.0, 4.0]]) # Norms: [2, 2] + weights3 = weighting(gramian3) # Step 3: after warmup + + # Normalized: [2/2, 2/4] = [1, 0.5] + # log: [0, -0.693], max=0, reg: [0, -0.693] + # exp: [1, 0.5], weights should favor first task + assert weights3[0] > weights3[1] + + +def test_reset_clears_state(): + """Test that reset() clears the warmup state.""" + weighting = STCHWeighting(mu=1.0, warmup_steps=2) + + gramian = torch.tensor([[4.0, 0.0], [0.0, 16.0]]) + + # Go through warmup + weighting(gramian) + weighting(gramian) + weighting(gramian) # After warmup + + assert weighting.step == 3 + assert weighting.nadir_vector is not None + + # Reset + weighting.reset() + + assert weighting.step == 0 + assert weighting.nadir_vector is None + assert weighting.nadir_accumulator is None + + # Should be in warmup again + weights = weighting(gramian) + assert torch.isclose(weights[0], torch.tensor(0.5), atol=1e-6) + + +def test_aggregator_reset(): + """Test that STCH.reset() properly resets the weighting state.""" + A = STCH(mu=1.0, warmup_steps=2) + + matrix = torch.tensor([[2.0, 0.0], [0.0, 4.0]]) + + # Go through warmup + A(matrix) + A(matrix) + A(matrix) + + # Reset through aggregator + A.reset() + + # Weighting should be reset + assert A._stch_weighting.step == 0 + assert A._stch_weighting.nadir_vector is None + + +def test_no_warmup_when_warmup_steps_none(): + """Test that no warmup occurs when warmup_steps is None.""" + weighting = STCHWeighting(mu=1.0, warmup_steps=None) + gramian = torch.tensor([[1.0, 0.0], [0.0, 100.0]]) + + # Should immediately use STCH weighting (not uniform) + weights = weighting(gramian) + + # With mu=1.0 and norms [1, 10], the second task should have higher weight + assert weights[1] > weights[0] + + +def test_warmup_step_counter(): + """Test that the step counter increments correctly.""" + weighting = STCHWeighting(mu=1.0, warmup_steps=3) + gramian = torch.tensor([[4.0, 0.0], [0.0, 16.0]]) + + assert weighting.step == 0 + + weighting(gramian) + assert weighting.step == 1 + + weighting(gramian) + assert weighting.step == 2 + + weighting(gramian) + assert weighting.step == 3 + + weighting(gramian) # First step after warmup: nadir_vector gets computed + assert weighting.step == 4 + + weighting(gramian) # Steady-state: nadir_vector is already set + assert weighting.step == 5 + + +def test_warmup_with_varying_gramians(): + """Test warmup with different gramians to verify accumulation.""" + weighting = STCHWeighting(mu=1.0, warmup_steps=2) + + gramian1 = torch.tensor([[1.0, 0.0], [0.0, 4.0]]) # Norms: [1, 2] + gramian2 = torch.tensor([[9.0, 0.0], [0.0, 16.0]]) # Norms: [3, 4] + + weighting(gramian1) # Step 1: warmup + weighting(gramian2) # Step 2: warmup (completes warmup) + + # nadir_vector is computed on first call AFTER warmup + gramian3 = torch.tensor([[4.0, 0.0], [0.0, 4.0]]) + 
+    weighting(gramian3)  # Step 3: computes nadir_vector and uses it
+
+    # Nadir should be average: [(1+3)/2, (2+4)/2] = [2, 3]
+    expected_nadir = torch.tensor([2.0, 3.0])
+    assert weighting.nadir_vector is not None
+    assert torch.allclose(weighting.nadir_vector, expected_nadir, atol=1e-6)
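End-of-patch note, not part of the diff: a minimal usage sketch of the warmup/reset workflow, relying only on calls exercised above (constructing ``STCH``, applying it to a Jacobian with one row per task, and ``reset()``)::

    import torch
    from torchjd.aggregation import STCH

    aggregator = STCH(mu=1.0, warmup_steps=2)
    jacobian = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])  # one row per task

    for step in range(4):
        # Steps 0-1: uniform weights while gradient norms are accumulated for the nadir.
        # Steps 2-3: Smooth Tchebycheff weights, normalized by the accumulated nadir.
        print(step, aggregator(jacobian))

    aggregator.reset()  # clear the step counter and nadir before the next experiment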