29 changes: 29 additions & 0 deletions .github/workflows/tests.yml
@@ -77,6 +77,35 @@ jobs:
working-directory: docs
run: uv run make dirhtml

check-links:
name: Check no link is broken
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4

# This will restore the cache for the current commit if it exists, or the most recent lychee
# cache otherwise (including those saved for the main branch). It will also save the cache for
# the current commit if none existed for it, and only if the link check succeeded. We don't
# want to save a cache when the action failed, because the reason for failure might be
# temporary (rate limiting, network issue, etc.), and we always want to retry those links
# every time this action is run.
- name: Restore lychee cache
uses: actions/cache@v4
with:
path: .lycheecache
key: cache-lychee-${{ github.sha }}
restore-keys: cache-lychee-

- name: Run lychee
uses: lycheeverse/lychee-action@v2
with:
args: --verbose --no-progress --cache --max-cache-age 1d "." --exclude-path "docs/source/_templates/page.html"
fail: true
env:
# This reduces false positives due to rate limits
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

mypy:
name: Run mypy
runs-on: ubuntu-latest
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -299,7 +299,7 @@ Informed Neural Networks](https://arxiv.org/pdf/2408.11104).
- `Aggregator` base class to aggregate Jacobian matrices.
- `AlignedMTL` from [Independent Component
Alignment for Multi-Task Learning](
https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf>).
https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf).
- `CAGrad` from [Conflict-Averse Gradient Descent for Multi-task
Learning](https://arxiv.org/pdf/2110.14048.pdf).
- `Constant` to aggregate with constant weights.
10 changes: 5 additions & 5 deletions README.md
@@ -63,13 +63,13 @@ There are two main ways to use TorchJD. The first one is to replace the usual ca
[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/), depending
on the use-case. This will compute the Jacobian of the vector of losses with respect to the model
parameters, and aggregate it with the specified
[`Aggregator`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Aggregator).
[`Aggregator`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Aggregator).
Whenever you want to optimize the vector of per-sample losses, you should instead use the
[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine.html). Instead of
[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine/). Instead of
computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a
memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from
this Gramian, using a
[`Weighting`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Weighting),
[`Weighting`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Weighting),
and used to combine the losses of the batch. Assuming each element of the batch is
processed independently of the others, this approach is equivalent to
[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) while being
@@ -210,7 +210,7 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
> [!NOTE]
> Here, because the losses are a matrix instead of a simple vector, we compute a *generalized
> Gramian* and we extract weights from it using a
> [GeneralizedWeighting](https://torchjd.org/docs/aggregation/index.html#torchjd.aggregation.GeneralizedWeighting).
> [GeneralizedWeighting](https://torchjd.org/stable/docs/aggregation/#torchjd.aggregation.GeneralizedWeighting).

More usage examples can be found [here](https://torchjd.org/stable/examples/).
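To make the Gramian idea above concrete, here is a minimal, dependency-free Python sketch. It is illustrative only, not the torchjd API: for a Jacobian `J` whose rows are per-loss (or per-sample) gradients, the Gramian `G = J Jᵀ` holds every pairwise inner product between gradients, so any weighting that depends only on those inner products can be computed from the small m×m Gramian; the combined update direction is then the weighted sum `Jᵀw` of the gradient rows.

```python
# Illustrative sketch (plain Python, NOT the torchjd API).

def gramian(J):
    """G[i][j] = <J[i], J[j]> for a Jacobian given as a list of gradient rows."""
    return [[sum(a * b for a, b in zip(ri, rj)) for rj in J] for ri in J]

def combine(J, w):
    """Combined update direction J^T w: a weighted sum of the gradient rows."""
    n = len(J[0])
    return [sum(w[i] * J[i][k] for i in range(len(J))) for k in range(n)]

# Two per-loss gradients in R^3.
J = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]
G = gramian(J)    # [[5.0, -2.0], [-2.0, 2.0]]: off-diagonal < 0 means conflict
w = [0.5, 0.5]    # e.g. a uniform weighting extracted from G
g = combine(J, w) # [0.5, 0.5, 0.5]
```

The memory saving comes from the shapes: `G` is m×m (m losses), while `J` is m×n with n the number of parameters, typically far larger.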

@@ -220,7 +220,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo
<!-- recommended aggregators first, then alphabetical order -->
| Aggregator | Weighting | Publication |
|------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad.html#torchjd.aggregation.UPGrad) (recommended) | [UPGradWeighting](https://torchjd.org/stable/docs/aggregation/upgrad#torchjd.aggregation.UPGradWeighting) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/#torchjd.aggregation.UPGrad) (recommended) | [UPGradWeighting](https://torchjd.org/stable/docs/aggregation/upgrad/#torchjd.aggregation.UPGradWeighting) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTL) | [AlignedMTLWeighting](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTLWeighting) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGrad) | [CAGradWeighting](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGradWeighting) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
| [ConFIG](https://torchjd.org/stable/docs/aggregation/config#torchjd.aggregation.ConFIG) | - | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_aligned_mtl.py
@@ -51,8 +51,8 @@ class AlignedMTL(GramianWeightedAggregator):
uses the mean eigenvalue (as in the original implementation).

.. note::
This implementation was adapted from the `official implementation
<https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned>`_.
This implementation was adapted from the official implementation of SamsungLabs/MTL,
which is no longer available at the time of writing.
"""

def __init__(
5 changes: 3 additions & 2 deletions src/torchjd/aggregation/_mgda.py
@@ -11,8 +11,9 @@ class MGDA(GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` performing the gradient aggregation
step of `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization
<https://www.sciencedirect.com/science/article/pii/S1631073X12000738>`_. The implementation is
based on Algorithm 2 of `Multi-Task Learning as Multi-Objective Optimization
<https://comptes-rendus.academie-sciences.fr/mathematique/articles/10.1016/j.crma.2012.03.014/>`_.
The implementation is based on Algorithm 2 of `Multi-Task Learning as Multi-Objective
Optimization
<https://proceedings.neurips.cc/paper_files/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf>`_.

:param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization.
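As an illustrative aside (this is not the code from this file), the min-norm subproblem at the heart of MGDA has a closed form in the two-gradient case: minimizing the norm of the convex combination gamma*g1 + (1-gamma)*g2 over gamma in [0, 1] gives gamma* = ((g2-g1)·g2) / ||g1-g2||², clipped to [0, 1]. A minimal Python sketch:

```python
# Hedged sketch (NOT torchjd's implementation): the two-gradient min-norm
# step underlying MGDA. The minimum-norm point of the segment between g1
# and g2 is gamma*g1 + (1-gamma)*g2 with a closed-form gamma.

def min_norm_two(g1, g2):
    """Return (gamma, d) minimizing ||gamma*g1 + (1-gamma)*g2|| over [0, 1]."""
    diff = [a - b for a, b in zip(g1, g2)]
    denom = sum(x * x for x in diff)  # ||g1 - g2||^2
    if denom == 0.0:                  # g1 == g2: any gamma gives the same point
        gamma = 0.5
    else:
        gamma = sum(-x * b for x, b in zip(diff, g2)) / denom
        gamma = min(1.0, max(0.0, gamma))
    d = [gamma * a + (1.0 - gamma) * b for a, b in zip(g1, g2)]
    return gamma, d

# Directly opposed gradients: the min-norm combination collapses to zero.
gamma, d = min_norm_two([1.0, 0.0], [-1.0, 0.0])  # gamma = 0.5, d = [0.0, 0.0]
```

Algorithm 2 of the Sener & Koltun paper linked above extends this pairwise solution to more than two gradients via Frank-Wolfe iterations on the simplex of weights.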