
feat(aggregation): Add GradVac aggregator#638

Open
rkhosrowshahi wants to merge 11 commits into SimplexLab:main from rkhosrowshahi:feature/gradvac

Conversation

@rkhosrowshahi
Contributor

@rkhosrowshahi rkhosrowshahi commented Apr 9, 2026

Summary

Adds Gradient Vaccine (GradVac) from ICLR 2021 as a stateful Aggregator on the full task Jacobian.

Behavior

  • Per-block cosine statistics and EMA targets ρ̄, with the closed-form vaccine update applied when ρ < ρ̄.
  • group_type: 0 = whole model (single block); 1 = all_layer via encoder (leaf modules with parameters); 2 = all_matrix via shared_params (one block per tensor; iteration order matches Jacobian column order).
  • DEFAULT_GRADVAC_EPS and configurable eps (constructor + mutable attribute).
  • Autogram not supported (needs full rows and per-block inner products). Task shuffle uses torch.randperm; use torch.manual_seed for reproducibility.
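For reference, the per-pair update described above can be sketched roughly as follows. This is a simplified, whole-model (single-block) version; the function name `gradvac_step` and the argument names are illustrative, not the PR's actual API:

```python
import torch

def gradvac_step(grads: torch.Tensor, rho_bar: torch.Tensor,
                 beta: float = 0.01, eps: float = 1e-8) -> torch.Tensor:
    """One GradVac pass over a (num_tasks, dim) gradient matrix, single block."""
    pc_grads = grads.clone()
    n = grads.shape[0]
    for i in range(n):
        # Task order is shuffled with the global PyTorch RNG, as noted above.
        for j in torch.randperm(n).tolist():
            if j == i:
                continue
            g_i, g_j = pc_grads[i], grads[j]
            norm_i, norm_j = g_i.norm(), g_j.norm()
            rho = torch.dot(g_i, g_j) / (norm_i * norm_j + eps)
            bar = rho_bar[i, j]
            if rho < bar:
                # Closed-form vaccine coefficient (Wang et al., ICLR 2021).
                sqrt_1_rho2 = (1.0 - rho * rho).clamp(min=0.0).sqrt()
                sqrt_1_bar2 = (1.0 - bar * bar).clamp(min=0.0).sqrt()
                w = norm_i * (bar * sqrt_1_rho2 - rho * sqrt_1_bar2) \
                    / (norm_j * sqrt_1_bar2 + eps)
                pc_grads[i] = pc_grads[i] + w * g_j
            # EMA update of the target similarity.
            rho_bar[i, j] = (1.0 - beta) * bar + beta * rho
    return pc_grads.sum(dim=0)
```

In the multi-block modes (all_layer, all_matrix), the same update runs independently on each block's slice of the Jacobian rows.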

Files

  • src/torchjd/aggregation/_gradvac.py, export in __init__.py
  • docs/source/docs/aggregation/gradvac.rst + index toctree
  • tests/unit/aggregation/test_gradvac.py

Verification

  • ruff format / ruff check on touched paths
  • ty check on _gradvac.py
  • pytest tests/unit/aggregation/test_gradvac.py tests/unit/aggregation/test_values.py -W error
  • Sphinx HTML build for docs

Implement Gradient Vaccine (ICLR 2021) as a stateful Jacobian aggregator.
Support group_type 0 (whole model), 1 (all_layer via encoder), and 2
(all_matrix via shared_params), with DEFAULT_GRADVAC_EPS and configurable
eps. Add Sphinx page and unit tests. Autogram is not supported; use
torch.manual_seed for reproducible task shuffle order.

Made-with: Cursor
@ValerianRey ValerianRey added cc: feat Conventional commit type for new features. package: aggregation labels Apr 9, 2026
@rkhosrowshahi
Contributor Author

Hey @ValerianRey :
Could you please add cc: feat and package: aggregation labels to this PR?

@github-actions github-actions bot changed the title feat(aggregation): add GradVac aggregator feat(aggregation): Add GradVac aggregator Apr 9, 2026
@ValerianRey
Contributor

I think the link check action failed because of bad luck, re-running the job fixed it.

@ValerianRey
Contributor

Thanks a lot for the PR, I'll look at it tonight or tomorrow!

@ValerianRey
Contributor

/opencode:Plan can you please check that the actual implementation (from line 219 to line 251 of _gradvac.py) is equivalent to the implementation in https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/GradVac.py? Please report anything that differs.

@opencode-agent
Contributor

Let me analyze the differences between the two implementations:

Comparison Summary

After comparing lines 219-251 of _gradvac.py with the LibMTL reference implementation, here are the findings:

✅ Equivalent Behaviors

  1. Core Algorithm: Both implement the same GradVac algorithm:

    • Clone the gradients to pc_grads
    • Iterate over task pairs (i, j) where i ≠ j
    • Random shuffle of j for each i
    • Compute cosine similarity rho_ijk per block
    • Apply vaccine update when rho_ijk < bar (EMA target)
    • Update EMA: bar = (1-beta)*bar + beta*rho_ijk
  2. Weight Formula: Mathematically identical:

    • LibMTL: ||g_i||*(ρ̄*√(1-ρ²)-ρ*√(1-ρ̄²))/(||g_j||*√(1-ρ̄²)+ε)
    • TorchJD: norm_i*(bar*sqrt_1_rho2-rho_ijk*sqrt_1_bar2)/(norm_j*sqrt_1_bar2+eps)
  3. Group Types: Both support the same three grouping strategies (whole_model, all_layer, all_matrix)
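As a quick numeric sanity check that the two notations above are the same formula (the values here are illustrative, not taken from either codebase):

```python
import math

norm_i, norm_j = 2.0, 3.0          # ||g_i||, ||g_j||
rho, bar, eps = -0.5, 0.1, 1e-8    # cosine rho, EMA target rho-bar, epsilon

# LibMTL notation: ||g_i||*(rho_bar*sqrt(1-rho^2) - rho*sqrt(1-rho_bar^2)) / (||g_j||*sqrt(1-rho_bar^2) + eps)
w_libmtl = norm_i * (bar * math.sqrt(1 - rho ** 2) - rho * math.sqrt(1 - bar ** 2)) \
    / (norm_j * math.sqrt(1 - bar ** 2) + eps)

# TorchJD notation: same terms under different names.
sqrt_1_rho2 = math.sqrt(1 - rho * rho)
sqrt_1_bar2 = math.sqrt(1 - bar * bar)
w_torchjd = norm_i * (bar * sqrt_1_rho2 - rho * sqrt_1_bar2) / (norm_j * sqrt_1_bar2 + eps)

assert math.isclose(w_libmtl, w_torchjd)
```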

🔍 Differences Found

| Aspect | LibMTL | TorchJD (this PR) |
| --- | --- | --- |
| RNG for shuffling | random.shuffle() (Python stdlib) | torch.randperm() (PyTorch) |
| Epsilon | Hardcoded 1e-8 | Configurable eps attribute |
| Numerical safety | No clamp before sqrt() | .clamp(min=0.0) before sqrt() |
| Block indexing | Runtime sum(self.k_idx[:k]) | Precomputed offsets list |
| Whole model case | Special case with k_idx=[-1] | Natural via sizes=(n,) |
| Return value | batch_weight (unused) | pc_grads.sum(dim=0) |
| State management | self.step counter | _state_key tuple |

📝 Notes

  1. The clamp operation in TorchJD ((1.0 - rho_ijk * rho_ijk).clamp(min=0.0).sqrt()) is a numerical safety improvement not present in LibMTL - it prevents potential NaN from negative values due to floating-point errors.

  2. The random number generator difference is documented in the TorchJD docstring: "Task-order shuffling uses the global PyTorch RNG".

  3. Both implementations correctly handle the asymmetric nature of the algorithm (updating pc_grads[i] based on comparisons with grads[j]).
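The value of the clamp is easy to demonstrate: in float32, a cosine can drift marginally above 1, making 1 − ρ² negative and its square root NaN (a minimal, self-contained illustration):

```python
import torch

# A cosine computed in float32 can drift slightly above 1.0.
rho = torch.tensor(1.0000001)

unsafe = (1.0 - rho * rho).sqrt()                 # sqrt of a negative -> NaN
safe = (1.0 - rho * rho).clamp(min=0.0).sqrt()    # clamped to 0 first -> 0.0

print(unsafe.item(), safe.item())
```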

Conclusion: The implementations are algorithmically equivalent. TorchJD includes minor improvements (configurable eps, numerical clamp) and follows the codebase's design patterns.

New session - 2026-04-11T12:55:53.983Z
opencode session  |  github run

@ValerianRey
Contributor

/opencode:Plan Please review this PR in detail.

@opencode-agent

This comment was marked as low quality.

Contributor

@ValerianRey ValerianRey left a comment


Really good work, tyvm! This seems almost ready to merge IMO. Just a bunch of nitpicks that should be really easy to fix, and 1 somewhat important implementation question (iterating over groups in outer loop vs inner loop).

I would also argue that this could be implemented as a Weighting internally, because we actually act on norms and cosine similarities between gradients, which is exactly what the Gramian contains. Moreover, it's possible to keep track of norms and cosine similarities between projected gradients even without the gradients themselves, just by operating on the Gramian. This is what we did to implement PCGrad as a Weighting.

For example, imagine g1 and g2 are two gradients. From the Gramian, you know ||g1|| and ||g2|| (the square roots of the diagonal elements) and g1 . g2 (an off-diagonal element), so you can deduce cos(g1, g2) from that.

If you compute g1' = g1 + w * g2, you can also directly deduce the norm of g1':
||g1'||² = ||g1||² + w² ||g2||² + 2w (g1 . g2), where every term on the right-hand side is known.

Similarly, you can compute g1' . g2 = (g1 + w * g2) . g2 = g1 . g2 + w ||g2||².

So even after projection, you still know the dot products between all of your gradients, meaning that you still know the "new" gramian.

I didn't think through it entirely but at a first glance it seems possible to adapt this as a weighting, because of that. The implementation may even be faster actually (because we have fewer norms to recompute). But it may be hard to implement, so IMO we should merge this without even trying to implement it as a Weighting, and we can always improve later. @PierreQuinton what do you think about that?
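To illustrate the idea, here is a rough sketch (not TorchJD's actual Weighting code) of updating the Gramian entries after the projection g1' = g1 + w * g2, without ever materializing g1':

```python
import torch

# Two example gradients and a projection weight (illustrative values).
g1 = torch.tensor([1.0, 2.0, 0.0])
g2 = torch.tensor([0.5, -1.0, 2.0])
w = 0.3

# Entries of the original 2x2 Gramian.
g11, g22, g12 = g1 @ g1, g2 @ g2, g1 @ g2

# Update the Gramian entries for g1' = g1 + w * g2 using only Gramian entries.
g11_new = g11 + w ** 2 * g22 + 2 * w * g12   # ||g1'||^2
g12_new = g12 + w * g22                      # g1' . g2

# Check against the explicit computation with the materialized g1'.
g1p = g1 + w * g2
assert torch.isclose(g11_new, g1p @ g1p)
assert torch.isclose(g12_new, g1p @ g2)
```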

@ValerianRey

This comment was marked as resolved.

@rkhosrowshahi
Contributor Author

Opencode's review was quite low quality, but it mentioned something that I missed: we need a test for GradVac in tests/unit/aggregation/test_values.py.

Similarly, I'd like to have GradVac added to tests/plots/interactive_plotter.py.

Thanks. I added GradVac to the plotter and made the code a bit more user-friendly. In the plot, PCGrad and GradVac produce the same aggregated gradient. If you like the changes, I can add them to the commit as well.
PCGrad vs. GradVac

@ValerianRey

This comment was marked as outdated.

@PierreQuinton

This comment was marked as resolved.

- Use group_type "whole_model" | "all_layer" | "all_matrix" instead of 0/1/2
- Remove DEFAULT_GRADVAC_EPS from the public API; keep default 1e-8; allow eps=0
- Validate beta via setter; tighten GradVac repr/str expectations
- Fix all_layer leaf sizing via children() and parameters() instead of private fields
- Trim redundant GradVac.rst prose; align docs with the new API
- Tests: GradVac cases, value regression with torch.manual_seed for GradVac
- Plotter: factory dict + fresh aggregator instances per update; legend from
  selected keys; MathJax labels and live angle/length readouts in the sidebar

This commit includes the GradVac implementation as an Aggregator class.
@rkhosrowshahi rkhosrowshahi requested a review from a team as a code owner April 12, 2026 16:54
…hting

GradVac only needs gradient norms and dot products, which are fully
determined by the Gramian. This makes GradVac compatible with the autogram path.

- Remove grouping parameters (group_type, encoder, shared_params) from GradVac
- Export GradVacWeighting publicly
@ValerianRey
Contributor

ValerianRey commented Apr 12, 2026

I think this is ready to merge, except for some plotting things. Can we remove the changes to the plotter and make plotter improvements in a different PR (except adding GradVac to the list of aggregators in the plotter)? I see a few issues in the plotter changes, and I'd rather merge this PR now and make the rest of the changes in a different PR. @rkhosrowshahi

BTW the link check action will fail because the links I added in the readme point to some documentation that will only be built after we merge this.
