From 32ef75e177782446c2af0a592a11ec6e3591d168 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 20 Oct 2025 12:59:36 +0200 Subject: [PATCH 001/182] Add `DiagonalSparseTensor` with the default fallback to dense mechanism. --- .../autogram/diagonal_sparse_tensor.py | 103 ++++++++++++++++++ .../autogram/test_diagonal_sparse_tensor.py | 34 ++++++ 2 files changed, 137 insertions(+) create mode 100644 src/torchjd/autogram/diagonal_sparse_tensor.py create mode 100644 tests/unit/autogram/test_diagonal_sparse_tensor.py diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py new file mode 100644 index 000000000..333661ce4 --- /dev/null +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -0,0 +1,103 @@ +import torch +from torch import Tensor +from torch.utils._pytree import tree_map + + +class DiagonalSparseTensor(torch.Tensor): + + @staticmethod + def __new__(cls, data: Tensor, v_to_p: list[int]): + # At the moment, this class is not compositional, so we assert + # that the tensor we're wrapping is exactly a Tensor + assert type(data) is Tensor + + # Note [Passing requires_grad=true tensors to subclasses] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Calling _make_subclass directly in an autograd context is + # never the right thing to do, as this will detach you from + # the autograd graph. You must create an autograd function + # representing the "constructor" (NegativeView, in this case) + # and call that instead. This assert helps prevent direct usage + # (which is bad!) + assert not data.requires_grad or not torch.is_grad_enabled() + + # There is something very subtle going on here. In particular, + # suppose that elem is a view. Does all of the view metadata + # (sizes, strides, storages) get propagated correctly? Yes! + # Internally, the way _make_subclass works is it creates an + # alias (using Tensor.alias) of the original tensor, which + # means we replicate storage/strides, but with the Python object + # as an instance of your subclass. In other words, + # _make_subclass is the "easy" case of metadata propagation, + # because anything that alias() propagates, you will get in + # your subclass. It is _make_wrapper_subclass which is + # problematic... + # + # TODO: We need to think about how we want to turn this into + # official API. I am thinking that something that does the + # assert above and this call could be made into a utility function + # that is in the public API + return Tensor._make_wrapper_subclass( + cls, [data.shape[i] for i in v_to_p], dtype=data.dtype, device=data.device + ) + + def __init__(self, data: Tensor, v_to_p: list[int]): + """ + Represent a diagonal sparse tensor. + + :param data: The physical contiguous data. + :param v_to_p: Maps virtual dimensions to physical dimensions. + + An example is `data` of shape `[d_1, d_2, d_3]` and `v_to_p` equal to `[0, 1, 0, 2, 1]` + means the virtual shape is `[d_1, d_2, d_1, d_3, d_2]` and the represented Tensor, indexed + at `[i, j, k, l, m]` is `0.` unless `i==k` and `j==m`. + """ + # Deliberate omission of `super().__init__()` as we have an unfaithful data. 
+ self._data = data + self._v_to_p = v_to_p + self._v_shape = tuple(data.shape[i] for i in v_to_p) + + def to_dense(self) -> Tensor: + first_indices = dict[int, int]() + identity_matrices = dict[int, Tensor]() + einsum_args: list[Tensor | list[int]] = [self._data, list(range(self._data.ndim))] + output_indices = list(range(len(self._v_to_p))) + for i, j in enumerate(self._v_to_p): + if j not in first_indices: + first_indices[j] = i + else: + if j not in identity_matrices: + device = self._data.device + dtype = self._data.dtype + identity_matrices[j] = torch.eye(self._v_shape[i], device=device, dtype=dtype) + einsum_args += [identity_matrices[j], [first_indices[j], i]] + + output = torch.einsum(*einsum_args, output_indices) + return output + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + + # TODO: Handle batched operations (apply to self._data and wrap) + # TODO: Handle all operations that can be represented with an einsum by translating them + # to operations on self._data and wrapping accordingly. + + # --- Fallback: Fold to Dense Tensor --- + def unwrap_to_dense(t): + if isinstance(t, cls): + return t.to_dense() + else: + return t + + print(f"Falling back to dense for {func.__name__}...") + return func(*tree_map(unwrap_to_dense, args), **tree_map(unwrap_to_dense, kwargs)) + + def __repr__(self): + return ( + f"DiagonalSparseTensor(\n" + f" data={self._data},\n" + f" v_to_p_map={self._v_to_p},\n" + f" shape={self._v_shape}\n" + f")" + ) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py new file mode 100644 index 000000000..38eb0d80a --- /dev/null +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -0,0 +1,34 @@ +import torch +from pytest import mark +from torch.testing import assert_close + +from torchjd.autogram.diagonal_sparse_tensor import DiagonalSparseTensor + + +@mark.parametrize( + "shape", + [ + [], + [1], + [3], + [1, 1], + [1, 4], + [3, 1], + [1, 2, 3], + ], +) +def test_diagonal_spase_tensor_scalar(shape: list[int]): + a = torch.randn(shape) + b = DiagonalSparseTensor(a, list(range(len(shape)))) + + assert_close(a, b) + + +@mark.parametrize("dim", [1, 2, 3, 4, 5, 10]) +def test_diag_equivalence(dim: int): + a = torch.randn([dim]) + b = DiagonalSparseTensor(a, [0, 0]) + + diag_a = torch.diag(a) + + assert_close(b, diag_a) From 5c76d69d4a614392470ae8ff85bf43c9297e95a1 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 20 Oct 2025 19:19:27 +0200 Subject: [PATCH 002/182] Ignore mypy --- .../autogram/diagonal_sparse_tensor.py | 24 ++++--------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 333661ce4..47028493f 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -37,25 +37,11 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): # official API. I am thinking that something that does the # assert above and this call could be made into a utility function # that is in the public API - return Tensor._make_wrapper_subclass( - cls, [data.shape[i] for i in v_to_p], dtype=data.dtype, device=data.device - ) - - def __init__(self, data: Tensor, v_to_p: list[int]): - """ - Represent a diagonal sparse tensor. - - :param data: The physical contiguous data. - :param v_to_p: Maps virtual dimensions to physical dimensions. 
- - An example is `data` of shape `[d_1, d_2, d_3]` and `v_to_p` equal to `[0, 1, 0, 2, 1]` - means the virtual shape is `[d_1, d_2, d_1, d_3, d_2]` and the represented Tensor, indexed - at `[i, j, k, l, m]` is `0.` unless `i==k` and `j==m`. - """ - # Deliberate omission of `super().__init__()` as we have an unfaithful data. - self._data = data - self._v_to_p = v_to_p - self._v_shape = tuple(data.shape[i] for i in v_to_p) + shape = [data.shape[i] for i in v_to_p] + result = Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) + result._data = data # type: ignore + result._v_to_p = v_to_p # type: ignore + result._v_shape = shape # type: ignore def to_dense(self) -> Tensor: first_indices = dict[int, int]() From 34f4dce706ebfa4670209364cebc02fbd87fcddd Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 20 Oct 2025 19:20:04 +0200 Subject: [PATCH 003/182] Remove useless comment. --- src/torchjd/autogram/diagonal_sparse_tensor.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 47028493f..8ae47c8ac 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -21,22 +21,6 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): # (which is bad!) assert not data.requires_grad or not torch.is_grad_enabled() - # There is something very subtle going on here. In particular, - # suppose that elem is a view. Does all of the view metadata - # (sizes, strides, storages) get propagated correctly? Yes! - # Internally, the way _make_subclass works is it creates an - # alias (using Tensor.alias) of the original tensor, which - # means we replicate storage/strides, but with the Python object - # as an instance of your subclass. In other words, - # _make_subclass is the "easy" case of metadata propagation, - # because anything that alias() propagates, you will get in - # your subclass. It is _make_wrapper_subclass which is - # problematic... - # - # TODO: We need to think about how we want to turn this into - # official API. 
I am thinking that something that does the - # assert above and this call could be made into a utility function - # that is in the public API shape = [data.shape[i] for i in v_to_p] result = Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) result._data = data # type: ignore From a0b7ffc941701c5708fd7301f658c942142276d8 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 20 Oct 2025 19:21:20 +0200 Subject: [PATCH 004/182] Change repr --- src/torchjd/autogram/diagonal_sparse_tensor.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 8ae47c8ac..525aa2dc3 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -65,9 +65,6 @@ def unwrap_to_dense(t): def __repr__(self): return ( - f"DiagonalSparseTensor(\n" - f" data={self._data},\n" - f" v_to_p_map={self._v_to_p},\n" - f" shape={self._v_shape}\n" - f")" + f"DiagonalSparseTensor(data={self._data}, v_to_p_map={self._v_to_p}, shape=" + f"{self._v_shape})" ) From f476b2975c28280df7b7d1e85f4fbadc710ced8d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 20 Oct 2025 19:26:35 +0200 Subject: [PATCH 005/182] revert removing `__init__` --- src/torchjd/autogram/diagonal_sparse_tensor.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 525aa2dc3..39886be4e 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from torch import Tensor from torch.utils._pytree import tree_map @@ -22,10 +24,12 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): assert not data.requires_grad or not torch.is_grad_enabled() shape = [data.shape[i] for i in v_to_p] - result = Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) - result._data = data # type: ignore - result._v_to_p = v_to_p # type: ignore - result._v_shape = shape # type: ignore + return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) + + def __init__(self, data: Tensor, v_to_p: list[int]): + self._data = data + self._v_to_p = v_to_p + self._v_shape = [data.shape[i] for i in v_to_p] def to_dense(self) -> Tensor: first_indices = dict[int, int]() @@ -46,7 +50,9 @@ def to_dense(self) -> Tensor: return output @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__( + cls, func: {__name__}, types: Any, args: tuple[()] | Any = (), kwargs: Any = None + ): kwargs = kwargs if kwargs else {} # TODO: Handle batched operations (apply to self._data and wrap) @@ -54,7 +60,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # to operations on self._data and wrapping accordingly. 
# --- Fallback: Fold to Dense Tensor --- - def unwrap_to_dense(t): + def unwrap_to_dense(t: Tensor): if isinstance(t, cls): return t.to_dense() else: From 447d714d972798a78eb0459aebb1cf48aa0262c4 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 21 Oct 2025 10:12:12 +0200 Subject: [PATCH 006/182] Give implementation for pointwise --- .../autogram/diagonal_sparse_tensor.py | 79 ++++++++++++++++++- .../autogram/test_diagonal_sparse_tensor.py | 32 +++++++- 2 files changed, 104 insertions(+), 7 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 39886be4e..99bdd357c 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -2,8 +2,70 @@ import torch from torch import Tensor +from torch.ops import aten from torch.utils._pytree import tree_map +# pointwise functions applied to one Tensor with `0.0 → 0` +_pointwise_functions = { + aten.abs.default, + aten.abs_.default, + aten.absolute.default, + aten.absolute_.default, + aten.neg.default, + aten.neg_.default, + aten.negative.default, + aten.negative_.default, + aten.sign.default, + aten.sign_.default, + aten.sgn.default, + aten.sgn_.default, + aten.square.default, + aten.square_.default, + aten.fix.default, + aten.fix_.default, + aten.floor.default, + aten.floor_.default, + aten.ceil.default, + aten.ceil_.default, + aten.trunc.default, + aten.trunc_.default, + aten.round.default, + aten.round_.default, + aten.positive.default, + aten.expm1.default, + aten.expm1_.default, + aten.log1p.default, + aten.log1p_.default, + aten.sqrt.default, + aten.sqrt_.default, + aten.sin.default, + aten.sin_.default, + aten.tan.default, + aten.tan_.default, + aten.sinh.default, + aten.sinh_.default, + aten.tanh.default, + aten.tanh_.default, + aten.asin.default, + aten.asin_.default, + aten.atan.default, + aten.atan_.default, + aten.asinh.default, + aten.asinh_.default, + aten.atanh.default, + aten.atanh_.default, + aten.erf.default, + aten.erf_.default, + aten.erfinv.default, + aten.erfinv_.default, + aten.relu.default, + aten.relu_.default, + aten.hardtanh.default, + aten.hardtanh_.default, + aten.leaky_relu.default, + aten.leaky_relu_.default, +} + class DiagonalSparseTensor(torch.Tensor): @@ -50,10 +112,19 @@ def to_dense(self) -> Tensor: return output @classmethod - def __torch_dispatch__( - cls, func: {__name__}, types: Any, args: tuple[()] | Any = (), kwargs: Any = None - ): - kwargs = kwargs if kwargs else {} + def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwargs: Any = None): + kwargs = {} if kwargs is None else kwargs + + # If `func` is a pointwise operator that applies to a single Tensor and such that func(0)=0 + # Then we can apply the transformation to self._data and wrap the result. 
+ if func in _pointwise_functions: + assert ( + isinstance(args, tuple) and len(args) == 1 and func(torch.zeros([])).item() == 0.0 + ) + sparse_tensor = args[0] + assert isinstance(sparse_tensor, DiagonalSparseTensor) + new_data = func(sparse_tensor._data) + return DiagonalSparseTensor(new_data, sparse_tensor._v_to_p) # TODO: Handle batched operations (apply to self._data and wrap) # TODO: Handle all operations that can be represented with an einsum by translating them diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 38eb0d80a..1f51721ca 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -1,8 +1,9 @@ import torch from pytest import mark from torch.testing import assert_close +from utils.tensors import randn_, zeros_ -from torchjd.autogram.diagonal_sparse_tensor import DiagonalSparseTensor +from torchjd.autogram.diagonal_sparse_tensor import DiagonalSparseTensor, _pointwise_functions @mark.parametrize( @@ -18,7 +19,7 @@ ], ) def test_diagonal_spase_tensor_scalar(shape: list[int]): - a = torch.randn(shape) + a = randn_(shape) b = DiagonalSparseTensor(a, list(range(len(shape)))) assert_close(a, b) @@ -26,9 +27,34 @@ def test_diagonal_spase_tensor_scalar(shape: list[int]): @mark.parametrize("dim", [1, 2, 3, 4, 5, 10]) def test_diag_equivalence(dim: int): - a = torch.randn([dim]) + a = randn_([dim]) b = DiagonalSparseTensor(a, [0, 0]) diag_a = torch.diag(a) assert_close(b, diag_a) + + +def test_three_virtual_single_physical(): + dim = 10 + a = randn_([dim]) + b = DiagonalSparseTensor(a, [0, 0, 0]) + + expected = zeros_([dim, dim, dim]) + for i in range(dim): + expected[i, i, i] = a[i] + + assert_close(b, expected) + + +@mark.parametrize("func", _pointwise_functions) +def test_pointwise(func): + dim = 100 + a = randn_([dim]) + b = DiagonalSparseTensor(a, [0, 0]) + c = b.to_dense() + d = func(b) + assert isinstance(d, DiagonalSparseTensor) + + # need to be careful about nans + assert_close(d, func(c)) From 85556a8083279db0e1bb87716934860b9edac931 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 21 Oct 2025 10:36:59 +0200 Subject: [PATCH 007/182] Add decorator to handle other functions. Add two examples of such functions (mean and sum). --- .../autogram/diagonal_sparse_tensor.py | 38 ++++++++++++++++--- .../autogram/test_diagonal_sparse_tensor.py | 23 ++++++++--- 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 99bdd357c..7ec1de3ad 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -6,7 +6,7 @@ from torch.utils._pytree import tree_map # pointwise functions applied to one Tensor with `0.0 → 0` -_pointwise_functions = { +_POINTWISE_FUNCTIONS = { aten.abs.default, aten.abs_.default, aten.absolute.default, @@ -65,6 +65,19 @@ aten.leaky_relu.default, aten.leaky_relu_.default, } +_HANDLED_FUNCTIONS = dict() +import functools + + +def implements(torch_function): + """Register a torch function override for ScalarTensor""" + + def decorator(func): + functools.update_wrapper(func, torch_function) + _HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator class DiagonalSparseTensor(torch.Tensor): @@ -85,6 +98,10 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): # (which is bad!) 
assert not data.requires_grad or not torch.is_grad_enabled() + # TODO: assert a minimal data, all of its dimensions must be used at least once + # TODO: If no repeat in v_to_p, return a view of data (non sparse tensor). If this cannot be + # done in __new__, create a helper function for that, and use this one everywhere. + shape = [data.shape[i] for i in v_to_p] return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) @@ -117,7 +134,7 @@ def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwar # If `func` is a pointwise operator that applies to a single Tensor and such that func(0)=0 # Then we can apply the transformation to self._data and wrap the result. - if func in _pointwise_functions: + if func in _POINTWISE_FUNCTIONS: assert ( isinstance(args, tuple) and len(args) == 1 and func(torch.zeros([])).item() == 0.0 ) @@ -126,9 +143,8 @@ def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwar new_data = func(sparse_tensor._data) return DiagonalSparseTensor(new_data, sparse_tensor._v_to_p) - # TODO: Handle batched operations (apply to self._data and wrap) - # TODO: Handle all operations that can be represented with an einsum by translating them - # to operations on self._data and wrapping accordingly. + if func in _HANDLED_FUNCTIONS: + return _HANDLED_FUNCTIONS[func](*args, **kwargs) # --- Fallback: Fold to Dense Tensor --- def unwrap_to_dense(t: Tensor): @@ -145,3 +161,15 @@ def __repr__(self): f"DiagonalSparseTensor(data={self._data}, v_to_p_map={self._v_to_p}, shape=" f"{self._v_shape})" ) + + +@implements(aten.mean.default) +def mean_default(t: Tensor) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + return aten.sum.default(t._data) / t.numel() + + +@implements(aten.sum.default) +def sum_default(t: Tensor) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + return aten.sum.default(t._data) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 1f51721ca..d72bc5404 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import randn_, zeros_ -from torchjd.autogram.diagonal_sparse_tensor import DiagonalSparseTensor, _pointwise_functions +from torchjd.autogram.diagonal_sparse_tensor import _POINTWISE_FUNCTIONS, DiagonalSparseTensor @mark.parametrize( @@ -47,14 +47,25 @@ def test_three_virtual_single_physical(): assert_close(b, expected) -@mark.parametrize("func", _pointwise_functions) +@mark.parametrize("func", _POINTWISE_FUNCTIONS) def test_pointwise(func): - dim = 100 + dim = 10 a = randn_([dim]) b = DiagonalSparseTensor(a, [0, 0]) c = b.to_dense() - d = func(b) - assert isinstance(d, DiagonalSparseTensor) + res = func(b) + assert isinstance(res, DiagonalSparseTensor) # need to be careful about nans - assert_close(d, func(c)) + assert_close(res, func(c)) + + +@mark.parametrize("func", [torch.mean, torch.sum]) +def test_mean(func): + dim = 10 + a = randn_([dim]) + b = DiagonalSparseTensor(a, [0, 0]) + c = b.to_dense() + + mean = func(b) + assert_close(mean, func(c)) From c5f868c9318304eb1811e841a346fd3471f8a330 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 21 Oct 2025 10:37:32 +0200 Subject: [PATCH 008/182] Improve naming. 
--- tests/unit/autogram/test_diagonal_sparse_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index d72bc5404..7e9ad78f3 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -67,5 +67,5 @@ def test_mean(func): b = DiagonalSparseTensor(a, [0, 0]) c = b.to_dense() - mean = func(b) - assert_close(mean, func(c)) + res = func(b) + assert_close(res, func(c)) From efa80194a1084b818ca07442b7f493d4e04613d0 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 22 Oct 2025 16:07:21 +0200 Subject: [PATCH 009/182] improve --- src/torchjd/autogram/diagonal_sparse_tensor.py | 1 + tests/unit/autogram/test_diagonal_sparse_tensor.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 7ec1de3ad..7cc203972 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -125,6 +125,7 @@ def to_dense(self) -> Tensor: identity_matrices[j] = torch.eye(self._v_shape[i], device=device, dtype=dtype) einsum_args += [identity_matrices[j], [first_indices[j], i]] + # Need to be careful about nans, we would want to get identity times nan. output = torch.einsum(*einsum_args, output_indices) return output diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 7e9ad78f3..cfa5368cc 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -56,12 +56,11 @@ def test_pointwise(func): res = func(b) assert isinstance(res, DiagonalSparseTensor) - # need to be careful about nans - assert_close(res, func(c)) + assert_close(res, func(c), equal_nan=True) @mark.parametrize("func", [torch.mean, torch.sum]) -def test_mean(func): +def test_unary(func): dim = 10 a = randn_([dim]) b = DiagonalSparseTensor(a, [0, 0]) From e36b3c5acd96c60461b7e0318c7dcb63b9e34f54 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 23 Oct 2025 11:38:11 +0200 Subject: [PATCH 010/182] Remove inplace functions from the list of pointwise functions (they should be implemented differently) --- .../autogram/diagonal_sparse_tensor.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 7cc203972..e5da3ab2c 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -8,62 +8,34 @@ # pointwise functions applied to one Tensor with `0.0 → 0` _POINTWISE_FUNCTIONS = { aten.abs.default, - aten.abs_.default, aten.absolute.default, - aten.absolute_.default, aten.neg.default, - aten.neg_.default, aten.negative.default, - aten.negative_.default, aten.sign.default, - aten.sign_.default, aten.sgn.default, - aten.sgn_.default, aten.square.default, - aten.square_.default, aten.fix.default, - aten.fix_.default, aten.floor.default, - aten.floor_.default, aten.ceil.default, - aten.ceil_.default, aten.trunc.default, - aten.trunc_.default, aten.round.default, - aten.round_.default, aten.positive.default, aten.expm1.default, - aten.expm1_.default, aten.log1p.default, - aten.log1p_.default, aten.sqrt.default, - aten.sqrt_.default, aten.sin.default, - aten.sin_.default, aten.tan.default, - 
aten.tan_.default, aten.sinh.default, - aten.sinh_.default, aten.tanh.default, - aten.tanh_.default, aten.asin.default, - aten.asin_.default, aten.atan.default, - aten.atan_.default, aten.asinh.default, - aten.asinh_.default, aten.atanh.default, - aten.atanh_.default, aten.erf.default, - aten.erf_.default, aten.erfinv.default, - aten.erfinv_.default, aten.relu.default, - aten.relu_.default, aten.hardtanh.default, - aten.hardtanh_.default, aten.leaky_relu.default, - aten.leaky_relu_.default, } _HANDLED_FUNCTIONS = dict() import functools From 018a9945bcb2e0d20cd8e7ccfb73347df68716ba Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 23 Oct 2025 12:41:05 +0200 Subject: [PATCH 011/182] Fix `to_dense` --- .../autogram/diagonal_sparse_tensor.py | 26 +++++++------------ .../autogram/test_diagonal_sparse_tensor.py | 1 + 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index e5da3ab2c..3791c9864 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -83,23 +83,15 @@ def __init__(self, data: Tensor, v_to_p: list[int]): self._v_shape = [data.shape[i] for i in v_to_p] def to_dense(self) -> Tensor: - first_indices = dict[int, int]() - identity_matrices = dict[int, Tensor]() - einsum_args: list[Tensor | list[int]] = [self._data, list(range(self._data.ndim))] - output_indices = list(range(len(self._v_to_p))) - for i, j in enumerate(self._v_to_p): - if j not in first_indices: - first_indices[j] = i - else: - if j not in identity_matrices: - device = self._data.device - dtype = self._data.dtype - identity_matrices[j] = torch.eye(self._v_shape[i], device=device, dtype=dtype) - einsum_args += [identity_matrices[j], [first_indices[j], i]] - - # Need to be careful about nans, we would want to get identity times nan. - output = torch.einsum(*einsum_args, output_indices) - return output + if self._data.ndim == 0: + return self._data + p_index_ranges = [torch.arange(s, device=self._data.device) for s in self._data.shape] + p_indices_grid = torch.meshgrid(*p_index_ranges) + v_indices_grid = [p_indices_grid[i] for i in self._v_to_p] + + res = torch.zeros(self.shape, device=self._data.device, dtype=self._data.dtype) + res[v_indices_grid] = self._data + return res @classmethod def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwargs: Any = None): diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index cfa5368cc..32d43244e 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -15,6 +15,7 @@ [1, 1], [1, 4], [3, 1], + [3, 4], [1, 2, 3], ], ) From e91323cd3df135425bca8f5ed606cd9d6af968dc Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 23 Oct 2025 12:46:55 +0200 Subject: [PATCH 012/182] Verify input. --- src/torchjd/autogram/diagonal_sparse_tensor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 3791c9864..75e7f9a14 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -74,6 +74,11 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): # TODO: If no repeat in v_to_p, return a view of data (non sparse tensor). If this cannot be # done in __new__, create a helper function for that, and use this one everywhere. 
+ if not all(0 <= i < data.ndim for i in v_to_p): + raise ValueError(f"Elements in v_to_p map to dimensions in data. Found {v_to_p}.") + if len(set(v_to_p)) != data.ndim: + raise ValueError("Every dimension in data must appear at least once in v_to_p.") + shape = [data.shape[i] for i in v_to_p] return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) From f7281539f76699ef306ef586f574013034ae419c Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 23 Oct 2025 12:52:24 +0200 Subject: [PATCH 013/182] Make a builder for DSPs and move checks in it. This should always be used over the constructor of `DST` --- .../autogram/diagonal_sparse_tensor.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 75e7f9a14..be820eb6d 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -70,15 +70,6 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): # (which is bad!) assert not data.requires_grad or not torch.is_grad_enabled() - # TODO: assert a minimal data, all of its dimensions must be used at least once - # TODO: If no repeat in v_to_p, return a view of data (non sparse tensor). If this cannot be - # done in __new__, create a helper function for that, and use this one everywhere. - - if not all(0 <= i < data.ndim for i in v_to_p): - raise ValueError(f"Elements in v_to_p map to dimensions in data. Found {v_to_p}.") - if len(set(v_to_p)) != data.ndim: - raise ValueError("Every dimension in data must appear at least once in v_to_p.") - shape = [data.shape[i] for i in v_to_p] return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) @@ -111,7 +102,7 @@ def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwar sparse_tensor = args[0] assert isinstance(sparse_tensor, DiagonalSparseTensor) new_data = func(sparse_tensor._data) - return DiagonalSparseTensor(new_data, sparse_tensor._v_to_p) + return diagonal_sparse_tensor(new_data, sparse_tensor._v_to_p) if func in _HANDLED_FUNCTIONS: return _HANDLED_FUNCTIONS[func](*args, **kwargs) @@ -133,6 +124,17 @@ def __repr__(self): ) +def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: + if not all(0 <= i < data.ndim for i in v_to_p): + raise ValueError(f"Elements in v_to_p map to dimensions in data. Found {v_to_p}.") + if len(set(v_to_p)) != data.ndim: + raise ValueError("Every dimension in data must appear at least once in v_to_p.") + if len(v_to_p) == data.ndim: + return torch.movedim(data, (list(range(data.ndim))), v_to_p) + else: + return DiagonalSparseTensor(data, v_to_p) + + @implements(aten.mean.default) def mean_default(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) From b2b5d7ae297af4b270d6574d1191a2bd747ac3bb Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 23 Oct 2025 16:09:40 +0200 Subject: [PATCH 014/182] Implement pointwise and inplace pointwise in `_HANDLED_FUNCTIONS`. 
--- .../autogram/diagonal_sparse_tensor.py | 63 +++++++++++++++---- 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index be820eb6d..cd606eb02 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -37,6 +37,37 @@ aten.hardtanh.default, aten.leaky_relu.default, } +_IN_PLACE_POINTWISE_FUNCTIONS = { + aten.abs_.default, + aten.absolute_.default, + aten.neg_.default, + aten.negative_.default, + aten.sign_.default, + aten.sgn_.default, + aten.square_.default, + aten.fix_.default, + aten.floor_.default, + aten.ceil_.default, + aten.trunc_.default, + aten.round_.default, + aten.positive_.default, + aten.expm1_.default, + aten.log1p_.default, + aten.sqrt_.default, + aten.sin_.default, + aten.tan_.default, + aten.sinh_.default, + aten.tanh_.default, + aten.asin_.default, + aten.atan_.default, + aten.asinh_.default, + aten.atanh_.default, + aten.erf_.default, + aten.erfinv_.default, + aten.relu_.default, + aten.hardtanh_.default, + aten.leaky_relu_.default, +} _HANDLED_FUNCTIONS = dict() import functools @@ -93,17 +124,6 @@ def to_dense(self) -> Tensor: def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwargs: Any = None): kwargs = {} if kwargs is None else kwargs - # If `func` is a pointwise operator that applies to a single Tensor and such that func(0)=0 - # Then we can apply the transformation to self._data and wrap the result. - if func in _POINTWISE_FUNCTIONS: - assert ( - isinstance(args, tuple) and len(args) == 1 and func(torch.zeros([])).item() == 0.0 - ) - sparse_tensor = args[0] - assert isinstance(sparse_tensor, DiagonalSparseTensor) - new_data = func(sparse_tensor._data) - return diagonal_sparse_tensor(new_data, sparse_tensor._v_to_p) - if func in _HANDLED_FUNCTIONS: return _HANDLED_FUNCTIONS[func](*args, **kwargs) @@ -135,13 +155,30 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: return DiagonalSparseTensor(data, v_to_p) +for func in _POINTWISE_FUNCTIONS: + + @implements(func) + def func(t: Tensor) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + return diagonal_sparse_tensor(func(t._data), t._v_to_p) + + +for func in _IN_PLACE_POINTWISE_FUNCTIONS: + + @implements(func) + def func(t: Tensor) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + func(t._data) + return t + + @implements(aten.mean.default) -def mean_default(t: Tensor) -> Tensor: +def mean(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) return aten.sum.default(t._data) / t.numel() @implements(aten.sum.default) -def sum_default(t: Tensor) -> Tensor: +def sum(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) return aten.sum.default(t._data) From 2b80788e8a5bd6f1eacb63260fbbd8e46319088d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 23 Oct 2025 16:10:22 +0200 Subject: [PATCH 015/182] Move Pointwise functions definitions. 
--- .../autogram/diagonal_sparse_tensor.py | 128 +++++++++--------- 1 file changed, 65 insertions(+), 63 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index cd606eb02..fe76a93ef 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -5,69 +5,6 @@ from torch.ops import aten from torch.utils._pytree import tree_map -# pointwise functions applied to one Tensor with `0.0 → 0` -_POINTWISE_FUNCTIONS = { - aten.abs.default, - aten.absolute.default, - aten.neg.default, - aten.negative.default, - aten.sign.default, - aten.sgn.default, - aten.square.default, - aten.fix.default, - aten.floor.default, - aten.ceil.default, - aten.trunc.default, - aten.round.default, - aten.positive.default, - aten.expm1.default, - aten.log1p.default, - aten.sqrt.default, - aten.sin.default, - aten.tan.default, - aten.sinh.default, - aten.tanh.default, - aten.asin.default, - aten.atan.default, - aten.asinh.default, - aten.atanh.default, - aten.erf.default, - aten.erfinv.default, - aten.relu.default, - aten.hardtanh.default, - aten.leaky_relu.default, -} -_IN_PLACE_POINTWISE_FUNCTIONS = { - aten.abs_.default, - aten.absolute_.default, - aten.neg_.default, - aten.negative_.default, - aten.sign_.default, - aten.sgn_.default, - aten.square_.default, - aten.fix_.default, - aten.floor_.default, - aten.ceil_.default, - aten.trunc_.default, - aten.round_.default, - aten.positive_.default, - aten.expm1_.default, - aten.log1p_.default, - aten.sqrt_.default, - aten.sin_.default, - aten.tan_.default, - aten.sinh_.default, - aten.tanh_.default, - aten.asin_.default, - aten.atan_.default, - aten.asinh_.default, - aten.atanh_.default, - aten.erf_.default, - aten.erfinv_.default, - aten.relu_.default, - aten.hardtanh_.default, - aten.leaky_relu_.default, -} _HANDLED_FUNCTIONS = dict() import functools @@ -155,6 +92,71 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: return DiagonalSparseTensor(data, v_to_p) +# pointwise functions applied to one Tensor with `0.0 → 0` +_POINTWISE_FUNCTIONS = { + aten.abs.default, + aten.absolute.default, + aten.neg.default, + aten.negative.default, + aten.sign.default, + aten.sgn.default, + aten.square.default, + aten.fix.default, + aten.floor.default, + aten.ceil.default, + aten.trunc.default, + aten.round.default, + aten.positive.default, + aten.expm1.default, + aten.log1p.default, + aten.sqrt.default, + aten.sin.default, + aten.tan.default, + aten.sinh.default, + aten.tanh.default, + aten.asin.default, + aten.atan.default, + aten.asinh.default, + aten.atanh.default, + aten.erf.default, + aten.erfinv.default, + aten.relu.default, + aten.hardtanh.default, + aten.leaky_relu.default, +} +_IN_PLACE_POINTWISE_FUNCTIONS = { + aten.abs_.default, + aten.absolute_.default, + aten.neg_.default, + aten.negative_.default, + aten.sign_.default, + aten.sgn_.default, + aten.square_.default, + aten.fix_.default, + aten.floor_.default, + aten.ceil_.default, + aten.trunc_.default, + aten.round_.default, + aten.positive_.default, + aten.expm1_.default, + aten.log1p_.default, + aten.sqrt_.default, + aten.sin_.default, + aten.tan_.default, + aten.sinh_.default, + aten.tanh_.default, + aten.asin_.default, + aten.atan_.default, + aten.asinh_.default, + aten.atanh_.default, + aten.erf_.default, + aten.erfinv_.default, + aten.relu_.default, + aten.hardtanh_.default, + aten.leaky_relu_.default, +} + + for func in _POINTWISE_FUNCTIONS: @implements(func) From 
85c8e4157141fbe86db7ef5fe7061a24619e14ab Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 23 Oct 2025 16:11:55 +0200 Subject: [PATCH 016/182] Clean filed of DST, remove virtual shape, it is just the shape, and make `data` and `v_to_p` public. --- .../autogram/diagonal_sparse_tensor.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index fe76a93ef..dbc3b35a5 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -42,19 +42,18 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) def __init__(self, data: Tensor, v_to_p: list[int]): - self._data = data - self._v_to_p = v_to_p - self._v_shape = [data.shape[i] for i in v_to_p] + self.data = data + self.v_to_p = v_to_p def to_dense(self) -> Tensor: - if self._data.ndim == 0: - return self._data - p_index_ranges = [torch.arange(s, device=self._data.device) for s in self._data.shape] + if self.data.ndim == 0: + return self.data + p_index_ranges = [torch.arange(s, device=self.data.device) for s in self.data.shape] p_indices_grid = torch.meshgrid(*p_index_ranges) - v_indices_grid = [p_indices_grid[i] for i in self._v_to_p] + v_indices_grid = [p_indices_grid[i] for i in self.v_to_p] - res = torch.zeros(self.shape, device=self._data.device, dtype=self._data.dtype) - res[v_indices_grid] = self._data + res = torch.zeros(self.shape, device=self.data.device, dtype=self.data.dtype) + res[v_indices_grid] = self.data return res @classmethod @@ -76,8 +75,8 @@ def unwrap_to_dense(t: Tensor): def __repr__(self): return ( - f"DiagonalSparseTensor(data={self._data}, v_to_p_map={self._v_to_p}, shape=" - f"{self._v_shape})" + f"DiagonalSparseTensor(data={self.data}, v_to_p_map={self.v_to_p}, shape=" + f"{self.shape})" ) @@ -162,7 +161,7 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: @implements(func) def func(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return diagonal_sparse_tensor(func(t._data), t._v_to_p) + return diagonal_sparse_tensor(func(t.data), t.v_to_p) for func in _IN_PLACE_POINTWISE_FUNCTIONS: @@ -170,17 +169,17 @@ def func(t: Tensor) -> Tensor: @implements(func) def func(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - func(t._data) + func(t.data) return t @implements(aten.mean.default) def mean(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t._data) / t.numel() + return aten.sum.default(t.data) / t.numel() @implements(aten.sum.default) def sum(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t._data) + return aten.sum.default(t.data) From 7c0bc4502b1cd1edc9788d00fba5c4e3fd8d3a48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 23 Oct 2025 19:31:33 +0200 Subject: [PATCH 017/182] Use DST for initial jac_output --- src/torchjd/autogram/_engine.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 361743a40..09bc6404f 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -9,6 +9,7 @@ from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms from ._jacobian_computer import AutogradJacobianComputer from ._module_hook_manager 
import ModuleHookManager +from .diagonal_sparse_tensor import DiagonalSparseTensor class Engine: @@ -173,7 +174,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - jac_output = _make_initial_jac_output(output) + jac_output = DiagonalSparseTensor(torch.ones_like(output), output_dims * 2) vmapped_diff = differentiation for _ in output_dims: @@ -193,15 +194,3 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: gramian_computer.reset() return gramian - - -def _make_initial_jac_output(output: Tensor) -> Tensor: - if output.ndim == 0: - return torch.ones_like(output) - p_index_ranges = [torch.arange(s, device=output.device) for s in output.shape] - p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") - v_indices_grid = p_indices_grid + p_indices_grid - - res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype) - res[v_indices_grid] = 1.0 - return res From 8509a3334fe0cc23743988aeabf225beeb21d12d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 24 Oct 2025 15:35:23 +0200 Subject: [PATCH 018/182] Add test for to_dense and inplace_pointwise --- .../autogram/diagonal_sparse_tensor.py | 1 - .../autogram/test_diagonal_sparse_tensor.py | 32 +++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index dbc3b35a5..3bda6b07f 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -136,7 +136,6 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: aten.ceil_.default, aten.trunc_.default, aten.round_.default, - aten.positive_.default, aten.expm1_.default, aten.log1p_.default, aten.sqrt_.default, diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 32d43244e..79f76a51e 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -3,7 +3,23 @@ from torch.testing import assert_close from utils.tensors import randn_, zeros_ -from torchjd.autogram.diagonal_sparse_tensor import _POINTWISE_FUNCTIONS, DiagonalSparseTensor +from torchjd.autogram.diagonal_sparse_tensor import ( + _IN_PLACE_POINTWISE_FUNCTIONS, + _POINTWISE_FUNCTIONS, + DiagonalSparseTensor, +) + + +def test_to_dense(): + n = 2 + m = 3 + a = randn_([m, n]) + b = DiagonalSparseTensor(a, [0, 1, 1, 0]) + c = b.to_dense() + + for i in range(n): + for j in range(m): + assert c[i, j, j, i] == a[i, j] @mark.parametrize( @@ -19,7 +35,7 @@ [1, 2, 3], ], ) -def test_diagonal_spase_tensor_scalar(shape: list[int]): +def test_diagonal_sparse_tensor_scalar(shape: list[int]): a = randn_(shape) b = DiagonalSparseTensor(a, list(range(len(shape)))) @@ -60,6 +76,18 @@ def test_pointwise(func): assert_close(res, func(c), equal_nan=True) +@mark.parametrize("func", _IN_PLACE_POINTWISE_FUNCTIONS) +def test_inplace_pointwise(func): + dim = 10 + a = randn_([dim]) + b = DiagonalSparseTensor(a, [0, 0]) + c = b.to_dense() + func(b) + assert isinstance(b, DiagonalSparseTensor) + + assert_close(b, func(c), equal_nan=True) + + @mark.parametrize("func", [torch.mean, torch.sum]) def test_unary(func): dim = 10 From 0eed3a313355ae8485db726ffecc97eee227a2fa Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 24 Oct 2025 15:35:27 +0200 Subject: [PATCH 019/182] Revert "Clean filed of DST, remove virtual shape, it is just the shape, 
and make `data` and `v_to_p` public." This reverts commit 85c8e4157141fbe86db7ef5fe7061a24619e14ab. --- .../autogram/diagonal_sparse_tensor.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 3bda6b07f..a104c48a2 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -42,18 +42,19 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) def __init__(self, data: Tensor, v_to_p: list[int]): - self.data = data - self.v_to_p = v_to_p + self._data = data + self._v_to_p = v_to_p + self._v_shape = [data.shape[i] for i in v_to_p] def to_dense(self) -> Tensor: - if self.data.ndim == 0: - return self.data - p_index_ranges = [torch.arange(s, device=self.data.device) for s in self.data.shape] + if self._data.ndim == 0: + return self._data + p_index_ranges = [torch.arange(s, device=self._data.device) for s in self._data.shape] p_indices_grid = torch.meshgrid(*p_index_ranges) - v_indices_grid = [p_indices_grid[i] for i in self.v_to_p] + v_indices_grid = [p_indices_grid[i] for i in self._v_to_p] - res = torch.zeros(self.shape, device=self.data.device, dtype=self.data.dtype) - res[v_indices_grid] = self.data + res = torch.zeros(self.shape, device=self._data.device, dtype=self._data.dtype) + res[v_indices_grid] = self._data return res @classmethod @@ -75,8 +76,8 @@ def unwrap_to_dense(t: Tensor): def __repr__(self): return ( - f"DiagonalSparseTensor(data={self.data}, v_to_p_map={self.v_to_p}, shape=" - f"{self.shape})" + f"DiagonalSparseTensor(data={self._data}, v_to_p_map={self._v_to_p}, shape=" + f"{self._v_shape})" ) @@ -160,7 +161,7 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: @implements(func) def func(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return diagonal_sparse_tensor(func(t.data), t.v_to_p) + return diagonal_sparse_tensor(func(t._data), t._v_to_p) for func in _IN_PLACE_POINTWISE_FUNCTIONS: @@ -168,17 +169,17 @@ def func(t: Tensor) -> Tensor: @implements(func) def func(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - func(t.data) + func(t._data) return t @implements(aten.mean.default) def mean(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t.data) / t.numel() + return aten.sum.default(t._data) / t.numel() @implements(aten.sum.default) def sum(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t.data) + return aten.sum.default(t._data) From 55a7cbc98c42cf917a39cd99ee475163873c3dd4 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 24 Oct 2025 15:43:31 +0200 Subject: [PATCH 020/182] Make `contiguous_data` and `v_to_p` public. 
--- .../autogram/diagonal_sparse_tensor.py | 33 ++++++++++--------- .../autogram/test_diagonal_sparse_tensor.py | 2 +- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index a104c48a2..26ea0024d 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -42,19 +42,22 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) def __init__(self, data: Tensor, v_to_p: list[int]): - self._data = data - self._v_to_p = v_to_p - self._v_shape = [data.shape[i] for i in v_to_p] + self.contiguous_data = data # self.data cannot be used here. + self.v_to_p = v_to_p def to_dense(self) -> Tensor: - if self._data.ndim == 0: - return self._data - p_index_ranges = [torch.arange(s, device=self._data.device) for s in self._data.shape] + if self.contiguous_data.ndim == 0: + return self.contiguous_data + p_index_ranges = [ + torch.arange(s, device=self.contiguous_data.device) for s in self.contiguous_data.shape + ] p_indices_grid = torch.meshgrid(*p_index_ranges) - v_indices_grid = [p_indices_grid[i] for i in self._v_to_p] + v_indices_grid = [p_indices_grid[i] for i in self.v_to_p] - res = torch.zeros(self.shape, device=self._data.device, dtype=self._data.dtype) - res[v_indices_grid] = self._data + res = torch.zeros( + self.shape, device=self.contiguous_data.device, dtype=self.contiguous_data.dtype + ) + res[v_indices_grid] = self.contiguous_data return res @classmethod @@ -76,8 +79,8 @@ def unwrap_to_dense(t: Tensor): def __repr__(self): return ( - f"DiagonalSparseTensor(data={self._data}, v_to_p_map={self._v_to_p}, shape=" - f"{self._v_shape})" + f"DiagonalSparseTensor(data={self.contiguous_data}, v_to_p_map={self.v_to_p}, shape=" + f"{self.shape})" ) @@ -161,7 +164,7 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: @implements(func) def func(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return diagonal_sparse_tensor(func(t._data), t._v_to_p) + return diagonal_sparse_tensor(func(t.contiguous_data), t.v_to_p) for func in _IN_PLACE_POINTWISE_FUNCTIONS: @@ -169,17 +172,17 @@ def func(t: Tensor) -> Tensor: @implements(func) def func(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - func(t._data) + func(t.contiguous_data) return t @implements(aten.mean.default) def mean(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t._data) / t.numel() + return aten.sum.default(t.contiguous_data) / t.numel() @implements(aten.sum.default) def sum(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t._data) + return aten.sum.default(t.contiguous_data) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 79f76a51e..2dcb98b7b 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -13,7 +13,7 @@ def test_to_dense(): n = 2 m = 3 - a = randn_([m, n]) + a = randn_([n, m]) b = DiagonalSparseTensor(a, [0, 1, 1, 0]) c = b.to_dense() From 0ede4ce8259d8e05621339c876fb665f24c30c61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 24 Oct 2025 17:58:18 +0200 Subject: [PATCH 021/182] Add linting comment when importing aten * The reason linters fail here is that ops is a namespace that is 
dynamically created. But it does exist, it does contain aten, and I've seen some documentation where functions torch.ops.aten.X are used. --- src/torchjd/autogram/diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 26ea0024d..c546ceff4 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -2,7 +2,7 @@ import torch from torch import Tensor -from torch.ops import aten +from torch.ops import aten # type: ignore[attr-defined] from torch.utils._pytree import tree_map _HANDLED_FUNCTIONS = dict() From bb165b09cbbb4bdb01c46351ca85992c2fc935e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 24 Oct 2025 18:04:35 +0200 Subject: [PATCH 022/182] Specify indexing="ij" in meshgrid call --- src/torchjd/autogram/diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index c546ceff4..0888c5820 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -51,7 +51,7 @@ def to_dense(self) -> Tensor: p_index_ranges = [ torch.arange(s, device=self.contiguous_data.device) for s in self.contiguous_data.shape ] - p_indices_grid = torch.meshgrid(*p_index_ranges) + p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") v_indices_grid = [p_indices_grid[i] for i in self.v_to_p] res = torch.zeros( From 6e57a3f392ad92551d863372d8509bec9f1cd749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 24 Oct 2025 18:05:51 +0200 Subject: [PATCH 023/182] Use tuple for v_indices_grid * Multiindexing with list is deprecated but with tuple is ok --- src/torchjd/autogram/diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 0888c5820..2b359104f 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -52,7 +52,7 @@ def to_dense(self) -> Tensor: torch.arange(s, device=self.contiguous_data.device) for s in self.contiguous_data.shape ] p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") - v_indices_grid = [p_indices_grid[i] for i in self.v_to_p] + v_indices_grid = tuple(p_indices_grid[i] for i in self.v_to_p) res = torch.zeros( self.shape, device=self.contiguous_data.device, dtype=self.contiguous_data.dtype From 8351af558a167c236a6751ef31cc352ce62c29de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 24 Oct 2025 18:14:28 +0200 Subject: [PATCH 024/182] Use lists for _POINTWISE_FUNCTIONS and _IN_PLACE_POINTWISE_FUNCTIONS * We currently only iterate over them, so it makes sense to have a list rather than a set * If we have a set, pytest parametrization order is not deterministic and we can't run a specific parametrization --- src/torchjd/autogram/diagonal_sparse_tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 2b359104f..bea106dcf 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -96,7 +96,7 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: # 
pointwise functions applied to one Tensor with `0.0 → 0` -_POINTWISE_FUNCTIONS = { +_POINTWISE_FUNCTIONS = [ aten.abs.default, aten.absolute.default, aten.neg.default, @@ -126,8 +126,8 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: aten.relu.default, aten.hardtanh.default, aten.leaky_relu.default, -} -_IN_PLACE_POINTWISE_FUNCTIONS = { +] +_IN_PLACE_POINTWISE_FUNCTIONS = [ aten.abs_.default, aten.absolute_.default, aten.neg_.default, @@ -156,7 +156,7 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: aten.relu_.default, aten.hardtanh_.default, aten.leaky_relu_.default, -} +] for func in _POINTWISE_FUNCTIONS: From 278bf245595b34f2f09e8162c613a430b54f3132 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 24 Oct 2025 18:17:56 +0200 Subject: [PATCH 025/182] Sort _POINTWISE_FUNCTIONS and _IN_PLACE_POINTWISE_FUNCTIONS * It makes it much easier to find them in pytest results * It makes it much easier to see if there's a duplicate --- .../autogram/diagonal_sparse_tensor.py | 79 ++++++++++--------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index bea106dcf..6c1af2812 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -99,63 +99,64 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: _POINTWISE_FUNCTIONS = [ aten.abs.default, aten.absolute.default, - aten.neg.default, - aten.negative.default, - aten.sign.default, - aten.sgn.default, - aten.square.default, - aten.fix.default, - aten.floor.default, - aten.ceil.default, - aten.trunc.default, - aten.round.default, - aten.positive.default, - aten.expm1.default, - aten.log1p.default, - aten.sqrt.default, - aten.sin.default, - aten.tan.default, - aten.sinh.default, - aten.tanh.default, aten.asin.default, - aten.atan.default, aten.asinh.default, + aten.atan.default, aten.atanh.default, + aten.ceil.default, aten.erf.default, aten.erfinv.default, - aten.relu.default, + aten.expm1.default, + aten.fix.default, + aten.floor.default, aten.hardtanh.default, aten.leaky_relu.default, + aten.log1p.default, + aten.neg.default, + aten.negative.default, + aten.positive.default, + aten.relu.default, + aten.round.default, + aten.sgn.default, + aten.sign.default, + aten.sin.default, + aten.sinh.default, + aten.sqrt.default, + aten.square.default, + aten.tan.default, + aten.tanh.default, + aten.trunc.default, ] + _IN_PLACE_POINTWISE_FUNCTIONS = [ aten.abs_.default, aten.absolute_.default, - aten.neg_.default, - aten.negative_.default, - aten.sign_.default, - aten.sgn_.default, - aten.square_.default, - aten.fix_.default, - aten.floor_.default, - aten.ceil_.default, - aten.trunc_.default, - aten.round_.default, - aten.expm1_.default, - aten.log1p_.default, - aten.sqrt_.default, - aten.sin_.default, - aten.tan_.default, - aten.sinh_.default, - aten.tanh_.default, aten.asin_.default, - aten.atan_.default, aten.asinh_.default, + aten.atan_.default, aten.atanh_.default, + aten.ceil_.default, aten.erf_.default, aten.erfinv_.default, - aten.relu_.default, + aten.expm1_.default, + aten.fix_.default, + aten.floor_.default, aten.hardtanh_.default, aten.leaky_relu_.default, + aten.log1p_.default, + aten.neg_.default, + aten.negative_.default, + aten.relu_.default, + aten.round_.default, + aten.sgn_.default, + aten.sign_.default, + aten.sin_.default, + aten.sinh_.default, + aten.sqrt_.default, + 
aten.square_.default, + aten.tan_.default, + aten.tanh_.default, + aten.trunc_.default, ] From d14a1a52c17af41736e2d75c01696fdafd201ef7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 24 Oct 2025 18:46:21 +0200 Subject: [PATCH 026/182] Fix pointwise function override * Not sure it's the most elegant way to do that but it works. --- .../autogram/diagonal_sparse_tensor.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 6c1af2812..c12134dbb 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -160,23 +160,30 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: ] -for func in _POINTWISE_FUNCTIONS: - - @implements(func) - def func(t: Tensor) -> Tensor: +def _override_pointwise(op): + @implements(op) + def func_(t: Tensor): assert isinstance(t, DiagonalSparseTensor) - return diagonal_sparse_tensor(func(t.contiguous_data), t.v_to_p) + return diagonal_sparse_tensor(op(t.contiguous_data), t.v_to_p) + return func_ -for func in _IN_PLACE_POINTWISE_FUNCTIONS: - @implements(func) - def func(t: Tensor) -> Tensor: +def _override_inplace_pointwise(op): + @implements(op) + def func_(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - func(t.contiguous_data) + op(t.contiguous_data) return t +for func in _POINTWISE_FUNCTIONS: + _override_pointwise(func) + +for func in _IN_PLACE_POINTWISE_FUNCTIONS: + _override_inplace_pointwise(func) + + @implements(aten.mean.default) def mean(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) From 17abe7f42a16113a068fdc6639b48de9332cd4a3 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 25 Oct 2025 17:48:46 +0200 Subject: [PATCH 027/182] Make `contiguous_data` and `v_to_p` public. --- .../autogram/diagonal_sparse_tensor.py | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index c12134dbb..d4633ce20 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -1,3 +1,6 @@ +import operator +from itertools import accumulate +from math import prod from typing import Any import torch @@ -23,7 +26,7 @@ def decorator(func): class DiagonalSparseTensor(torch.Tensor): @staticmethod - def __new__(cls, data: Tensor, v_to_p: list[int]): + def __new__(cls, data: Tensor, v_to_p: list[list[int]]): # At the moment, this class is not compositional, so we assert # that the tensor we're wrapping is exactly a Tensor assert type(data) is Tensor @@ -38,13 +41,23 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): # (which is bad!) assert not data.requires_grad or not torch.is_grad_enabled() - shape = [data.shape[i] for i in v_to_p] + shape = [prod(data.shape[i] for i in stride) for stride in v_to_p] return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) - def __init__(self, data: Tensor, v_to_p: list[int]): + def __init__(self, data: Tensor, v_to_p: list[list[int]]): self.contiguous_data = data # self.data cannot be used here. self.v_to_p = v_to_p + # This is a list of strides whose shape matches that of v_to_p except that each element + # is the stride factor of the index to get the right element for the corresponding virtual + # dimension. 
Stride is the jump necessary to go from one element to the next one in the + # specified dimension. For instance if the i'th element of v_to_p is [0, 1, 2], then the + # i'th element of _strides is [data.shape[1] * data.shape[2], data.shape[2], 1] and so, if + # we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2], which is + # a unique decomposition, then this corresponds to indexing dimensions v_to_p[i] at indices + # [j_0, j_1, j_2] + self._strides = [list(accumulate([1] + dims[:0:-1], operator.mul))[::-1] for dims in v_to_p] + def to_dense(self) -> Tensor: if self.contiguous_data.ndim == 0: return self.contiguous_data @@ -52,7 +65,14 @@ def to_dense(self) -> Tensor: torch.arange(s, device=self.contiguous_data.device) for s in self.contiguous_data.shape ] p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") - v_indices_grid = tuple(p_indices_grid[i] for i in self.v_to_p) + + v_indices_grid = list[Tensor]() + for stride, dims in zip(self._strides, self.v_to_p): + stride_ = torch.tensor(stride, device=self.contiguous_data.device, dtype=torch.int) + torch.sum(torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1) + # This is supposed to be a vector of shape d_1 * d_2 ... + # whose elements are the coordinates 1 in p_indices_grad[d_1] times stride 1 + # plus coordinates 2 in p_indices_grad[d_2] times stride 2, etc... res = torch.zeros( self.shape, device=self.contiguous_data.device, dtype=self.contiguous_data.dtype From 6bbd7027ce31971ca376b905159ca97de04381e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 02:45:05 +0200 Subject: [PATCH 028/182] Use wraps decorator instead of update_wrapper * This should have the exact same effect --- src/torchjd/autogram/diagonal_sparse_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index d4633ce20..a5363bf99 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -9,14 +9,14 @@ from torch.utils._pytree import tree_map _HANDLED_FUNCTIONS = dict() -import functools +from functools import wraps def implements(torch_function): """Register a torch function override for ScalarTensor""" + @wraps(func) def decorator(func): - functools.update_wrapper(func, torch_function) _HANDLED_FUNCTIONS[torch_function] = func return func From fac5f7d1a505adbaac770b0c4b6cb14ce2799cee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 02:47:20 +0200 Subject: [PATCH 029/182] Make densification explicit in tests before assert_close * Otherwise it's densified inside the assert_close, potentially several times. 
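(Aside, not part of the patch series: the `_strides` bookkeeping described in the multi-dimensional `v_to_p` change above is easiest to see on a concrete case. The following standalone sketch assumes a physical shape of (2, 3, 4) and one virtual dimension fusing physical dims [0, 1, 2]; the names `physical_shape`, `dims`, `sizes` are only for illustration.)

import operator
from itertools import accumulate

physical_shape = (2, 3, 4)
dims = [0, 1, 2]  # one entry of v_to_p: a virtual dim made of three physical dims

# Same row-major stride computation as in the comment above, written over sizes.
sizes = [physical_shape[d] for d in dims]
strides = list(accumulate([1] + sizes[:0:-1], operator.mul))[::-1]
assert strides == [12, 4, 1]  # i.e. [3 * 4, 4, 1]

# A flat index j into the fused virtual dim (of size 2 * 3 * 4 = 24) decomposes
# uniquely as j = j_0 * strides[0] + j_1 * strides[1] + j_2 * strides[2], which
# recovers the physical indices (j_0, j_1, j_2).
j = 17
j_0, j_1, j_2 = (j // 12) % 2, (j // 4) % 3, j % 4
assert (j_0, j_1, j_2) == (1, 1, 1)
assert j == j_0 * strides[0] + j_1 * strides[1] + j_2 * strides[2]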
--- tests/unit/autogram/test_diagonal_sparse_tensor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 2dcb98b7b..fe1979032 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -39,7 +39,7 @@ def test_diagonal_sparse_tensor_scalar(shape: list[int]): a = randn_(shape) b = DiagonalSparseTensor(a, list(range(len(shape)))) - assert_close(a, b) + assert_close(a, b.to_dense()) @mark.parametrize("dim", [1, 2, 3, 4, 5, 10]) @@ -49,7 +49,7 @@ def test_diag_equivalence(dim: int): diag_a = torch.diag(a) - assert_close(b, diag_a) + assert_close(b.to_dense(), diag_a) def test_three_virtual_single_physical(): @@ -61,7 +61,7 @@ def test_three_virtual_single_physical(): for i in range(dim): expected[i, i, i] = a[i] - assert_close(b, expected) + assert_close(b.to_dense(), expected) @mark.parametrize("func", _POINTWISE_FUNCTIONS) @@ -73,7 +73,7 @@ def test_pointwise(func): res = func(b) assert isinstance(res, DiagonalSparseTensor) - assert_close(res, func(c), equal_nan=True) + assert_close(res.to_dense(), func(c), equal_nan=True) @mark.parametrize("func", _IN_PLACE_POINTWISE_FUNCTIONS) @@ -85,7 +85,7 @@ def test_inplace_pointwise(func): func(b) assert isinstance(b, DiagonalSparseTensor) - assert_close(b, func(c), equal_nan=True) + assert_close(b.to_dense(), func(c), equal_nan=True) @mark.parametrize("func", [torch.mean, torch.sum]) @@ -96,4 +96,4 @@ def test_unary(func): c = b.to_dense() res = func(b) - assert_close(res, func(c)) + assert_close(res.to_dense(), func(c)) From 68bb2a34443aa0acb079fa66abed4cdb720425ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 02:48:45 +0200 Subject: [PATCH 030/182] Revert "Make `contiguous_data` and `v_to_p` public." This reverts commit 17abe7f42a16113a068fdc6639b48de9332cd4a3. --- .../autogram/diagonal_sparse_tensor.py | 28 +++---------------- 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index a5363bf99..2ba7a8470 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -1,6 +1,3 @@ -import operator -from itertools import accumulate -from math import prod from typing import Any import torch @@ -26,7 +23,7 @@ def decorator(func): class DiagonalSparseTensor(torch.Tensor): @staticmethod - def __new__(cls, data: Tensor, v_to_p: list[list[int]]): + def __new__(cls, data: Tensor, v_to_p: list[int]): # At the moment, this class is not compositional, so we assert # that the tensor we're wrapping is exactly a Tensor assert type(data) is Tensor @@ -41,23 +38,13 @@ def __new__(cls, data: Tensor, v_to_p: list[list[int]]): # (which is bad!) assert not data.requires_grad or not torch.is_grad_enabled() - shape = [prod(data.shape[i] for i in stride) for stride in v_to_p] + shape = [data.shape[i] for i in v_to_p] return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) - def __init__(self, data: Tensor, v_to_p: list[list[int]]): + def __init__(self, data: Tensor, v_to_p: list[int]): self.contiguous_data = data # self.data cannot be used here. 
self.v_to_p = v_to_p - # This is a list of strides whose shape matches that of v_to_p except that each element - # is the stride factor of the index to get the right element for the corresponding virtual - # dimension. Stride is the jump necessary to go from one element to the next one in the - # specified dimension. For instance if the i'th element of v_to_p is [0, 1, 2], then the - # i'th element of _strides is [data.shape[1] * data.shape[2], data.shape[2], 1] and so, if - # we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2], which is - # a unique decomposition, then this corresponds to indexing dimensions v_to_p[i] at indices - # [j_0, j_1, j_2] - self._strides = [list(accumulate([1] + dims[:0:-1], operator.mul))[::-1] for dims in v_to_p] - def to_dense(self) -> Tensor: if self.contiguous_data.ndim == 0: return self.contiguous_data @@ -65,14 +52,7 @@ def to_dense(self) -> Tensor: torch.arange(s, device=self.contiguous_data.device) for s in self.contiguous_data.shape ] p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") - - v_indices_grid = list[Tensor]() - for stride, dims in zip(self._strides, self.v_to_p): - stride_ = torch.tensor(stride, device=self.contiguous_data.device, dtype=torch.int) - torch.sum(torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1) - # This is supposed to be a vector of shape d_1 * d_2 ... - # whose elements are the coordinates 1 in p_indices_grad[d_1] times stride 1 - # plus coordinates 2 in p_indices_grad[d_2] times stride 2, etc... + v_indices_grid = tuple(p_indices_grid[i] for i in self.v_to_p) res = torch.zeros( self.shape, device=self.contiguous_data.device, dtype=self.contiguous_data.dtype From 3db55c37f24a1661a188910fa1c44f1ebe167440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 03:43:43 +0100 Subject: [PATCH 031/182] Add pow implementation for DST * This fixes the tests for square operations (in place and not in place) --- .../autogram/diagonal_sparse_tensor.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 2ba7a8470..eb0469d23 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -194,3 +194,29 @@ def mean(t: Tensor) -> Tensor: def sum(t: Tensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) return aten.sum.default(t.contiguous_data) + + +@implements(aten.pow.Tensor_Scalar) +def pow_Tensor_Scalar(t: Tensor, exponent: float) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + if exponent <= 0: + # Need to densify because we don't have pow(0, exponent) = 0 + return aten.pow.Tensor_Scalar(t.to_dense(), exponent) + + new_contiguous_data = aten.pow.Tensor_Scalar(t.contiguous_data, exponent) + return diagonal_sparse_tensor(new_contiguous_data, t.v_to_p) + + +# Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. +@implements(aten.pow_.Scalar) +def pow__Scalar(t: Tensor, exponent: float) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + if exponent <= 0: + # Need to densify because we don't have pow(0, exponent) = 0 + # Note sure if it's even possible to densify in-place, so let's just raise an error. 
+ raise ValueError(f"in-place pow with an exponent of {exponent} (<= 0) is not supported.") + + aten.pow_.Scalar(t.contiguous_data, exponent) + return t From 90b205c3ebcff82971264fd0b97b26bd3d28b563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 04:20:51 +0100 Subject: [PATCH 032/182] Remove type hint of __torch_dispatch__ so that mypy stops complaining --- src/torchjd/autogram/diagonal_sparse_tensor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index eb0469d23..c0606c52a 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -1,5 +1,3 @@ -from typing import Any - import torch from torch import Tensor from torch.ops import aten # type: ignore[attr-defined] @@ -61,7 +59,7 @@ def to_dense(self) -> Tensor: return res @classmethod - def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwargs: Any = None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs if func in _HANDLED_FUNCTIONS: From 2b360842e508f1c863a0b11a11718fcaa4574166 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 04:21:31 +0100 Subject: [PATCH 033/182] Make signature of to_dense match that of Tensor.to_dense. Since we don't support custom dtype or masked_grad, we just assert they're None. --- src/torchjd/autogram/diagonal_sparse_tensor.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index c0606c52a..be19b1df2 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -43,7 +43,12 @@ def __init__(self, data: Tensor, v_to_p: list[int]): self.contiguous_data = data # self.data cannot be used here. self.v_to_p = v_to_p - def to_dense(self) -> Tensor: + def to_dense( + self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None + ) -> Tensor: + assert dtype is None # We may add support for this later + assert masked_grad is None # We may add support for this later + if self.contiguous_data.ndim == 0: return self.contiguous_data p_index_ranges = [ From 6236baad6880b3e451a32914379da8ddcfadfb25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 04:24:13 +0100 Subject: [PATCH 034/182] Change type: ignore so that mypy doesn't complain * Otherwise it says that it doesn't find the type stubs for torch.ops. I don't think there's a way it can find them anyway. 
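(Aside on the pow override added a few patches above: the `exponent <= 0` branch exists because the implicit off-diagonal zeros of a diagonal sparse tensor only stay zero for strictly positive exponents. A minimal standalone illustration, using a plain dense diagonal matrix as a stand-in:)

import torch

diag = torch.diag(torch.tensor([2.0, 3.0]))

print(diag ** 2)   # off-diagonal entries stay 0.0, so squaring only the stored data is enough
print(diag ** 0)   # every entry becomes 1.0 -> the result is no longer diagonal
print(diag ** -1)  # off-diagonal zeros become inf -> densification is required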
--- src/torchjd/autogram/diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index be19b1df2..da44c9c35 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -1,6 +1,6 @@ import torch from torch import Tensor -from torch.ops import aten # type: ignore[attr-defined] +from torch.ops import aten # type: ignore from torch.utils._pytree import tree_map _HANDLED_FUNCTIONS = dict() From 6d28918391df6904a2a4c360fac94dd32734df3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 04:38:42 +0100 Subject: [PATCH 035/182] Add unsqueeze_default implementation for DST --- src/torchjd/autogram/diagonal_sparse_tensor.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index da44c9c35..1d7eac0ee 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -223,3 +223,18 @@ def pow__Scalar(t: Tensor, exponent: float) -> Tensor: aten.pow_.Scalar(t.contiguous_data, exponent) return t + + +@implements(aten.unsqueeze.default) +def unsqueeze_default(t: Tensor, dim: int) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + assert -t.ndim - 1 <= dim < t.ndim + 1 + + if dim < 0: + dim = t.ndim + dim + 1 + + new_data = aten.unsqueeze.default(t.contiguous_data, -1) + new_v_to_p = [p for p in t.v_to_p] # Deepcopy the list to not modify the original v_to_p + new_v_to_p.insert(dim, new_data.ndim - 1) + + return diagonal_sparse_tensor(new_data, new_v_to_p) From f4a436c66bb5656fb736d017c76dbda4f4439127 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 05:02:00 +0100 Subject: [PATCH 036/182] Add more info in the print when falling back to dense --- src/torchjd/autogram/diagonal_sparse_tensor.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 1d7eac0ee..1893e2369 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -77,7 +77,12 @@ def unwrap_to_dense(t: Tensor): else: return t - print(f"Falling back to dense for {func.__name__}...") + print(f"Falling back to dense for {func.__name__} called with the following args:") + for arg in args: + print(arg) + print("and the following kwargs:") + for k, v in kwargs.items(): + print(f"{k}: {v}") return func(*tree_map(unwrap_to_dense, args), **tree_map(unwrap_to_dense, kwargs)) def __repr__(self): From ad4d843c07ae3e834ec053ba356deacdf237d7b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 05:02:32 +0100 Subject: [PATCH 037/182] Add implementation for trivial views --- src/torchjd/autogram/diagonal_sparse_tensor.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 1893e2369..cecc9298d 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -243,3 +243,13 @@ def unsqueeze_default(t: Tensor, dim: int) -> Tensor: new_v_to_p.insert(dim, new_data.ndim - 1) return diagonal_sparse_tensor(new_data, new_v_to_p) + + +@implements(aten.view.default) +def view_default(t: 
Tensor, shape: list[int]) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + if shape == list(t.shape): + return DiagonalSparseTensor(t.contiguous_data, t.v_to_p) + else: + raise ValueError("Non-trivial view not supported yet.") From b353ff1fa855e2ee7d015b0538fe115f16a12fea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 05:46:14 +0100 Subject: [PATCH 038/182] Add expand_default, div_Scalar, and slice_Tensor --- .../autogram/diagonal_sparse_tensor.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index cecc9298d..7f9c93e05 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -112,6 +112,7 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: aten.atan.default, aten.atanh.default, aten.ceil.default, + aten.div.Scalar, aten.erf.default, aten.erfinv.default, aten.expm1.default, @@ -253,3 +254,49 @@ def view_default(t: Tensor, shape: list[int]) -> Tensor: return DiagonalSparseTensor(t.contiguous_data, t.v_to_p) else: raise ValueError("Non-trivial view not supported yet.") + + +@implements(aten.expand.default) +def expand_default(t: Tensor, sizes: list[int]) -> Tensor: + # note that sizes could also be just an int, or a torch.Size i think + assert isinstance(t, DiagonalSparseTensor) + assert isinstance(sizes, list) + assert len(sizes) == t.ndim + + new_contiguous_data_shape = [-1] * t.contiguous_data.ndim + + for dim, (original_size, new_size) in enumerate(zip(t.shape, sizes)): + if new_size != original_size: + assert original_size == 1 + + physical_dim = t.v_to_p[dim] + + # Verify that we don't have two virtual dims expanding the same physical dim differently + previous_value = new_contiguous_data_shape[physical_dim] + assert previous_value == -1 or previous_value == new_size + + new_contiguous_data_shape[physical_dim] = new_size + + new_contiguous_data = aten.expand.default(t.contiguous_data, new_contiguous_data_shape) + + # Not sure if it's safe to just provide v_to_p as-is. I think we're supposed to copy it. 
+ return diagonal_sparse_tensor(new_contiguous_data, t.v_to_p) + + +@implements(aten.div.Scalar) +def div_Scalar(t: Tensor, divisor: float) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + new_contiguous_data = aten.div.Scalar(t.contiguous_data, divisor) + return diagonal_sparse_tensor(new_contiguous_data, t.v_to_p) + + +@implements(aten.slice.Tensor) +def slice_Tensor(t: Tensor, dim: int, start: int | None, end: int | None, step: int = 1) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + physical_dim = t.v_to_p[dim] + + new_contiguous_data = aten.slice.Tensor(t.contiguous_data, physical_dim, start, end, step) + + return diagonal_sparse_tensor(new_contiguous_data, t.v_to_p) From d640266c7ebcce7c0ace9d59847e95a5fc05bca1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 06:03:30 +0100 Subject: [PATCH 039/182] Add mul_Tensor and transpose_int --- .../autogram/diagonal_sparse_tensor.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 7f9c93e05..0ac85809e 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -300,3 +300,23 @@ def slice_Tensor(t: Tensor, dim: int, start: int | None, end: int | None, step: new_contiguous_data = aten.slice.Tensor(t.contiguous_data, physical_dim, start, end, step) return diagonal_sparse_tensor(new_contiguous_data, t.v_to_p) + + +@implements(aten.mul.Tensor) +def mul_Tensor(t1: Tensor, t2: Tensor) -> Tensor: + # Element-wise multiplication where t1 is dense and t2 is DST + assert isinstance(t2, DiagonalSparseTensor) + + new_contiguous_data = aten.mul.Tensor(t1, t2.contiguous_data) + return diagonal_sparse_tensor(new_contiguous_data, t2.v_to_p) + + +@implements(aten.transpose.int) +def transpose_int(t: Tensor, dim0: int, dim1: int) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + new_v_to_p = [p for p in t.v_to_p] + new_v_to_p[dim0] = t.v_to_p[dim1] + new_v_to_p[dim1] = t.v_to_p[dim0] + + return diagonal_sparse_tensor(t.contiguous_data, new_v_to_p) From 6145e4bc250073ccd6ad63d52b8e4849339942b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 06:07:11 +0100 Subject: [PATCH 040/182] Use diagonal_sparse_tensor in view_default --- src/torchjd/autogram/diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 0ac85809e..074f06470 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -251,7 +251,7 @@ def view_default(t: Tensor, shape: list[int]) -> Tensor: assert isinstance(t, DiagonalSparseTensor) if shape == list(t.shape): - return DiagonalSparseTensor(t.contiguous_data, t.v_to_p) + return diagonal_sparse_tensor(t.contiguous_data, t.v_to_p) else: raise ValueError("Non-trivial view not supported yet.") From 444e7cd538fba162cd83e3b33cd4e379b8a83bfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 06:07:36 +0100 Subject: [PATCH 041/182] Use DiagonalSparseTensor type hint when applicable --- .../autogram/diagonal_sparse_tensor.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 074f06470..7ca2e498c 100644 --- 
a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -194,19 +194,19 @@ def func_(t: Tensor) -> Tensor: @implements(aten.mean.default) -def mean(t: Tensor) -> Tensor: +def mean(t: DiagonalSparseTensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) return aten.sum.default(t.contiguous_data) / t.numel() @implements(aten.sum.default) -def sum(t: Tensor) -> Tensor: +def sum(t: DiagonalSparseTensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) return aten.sum.default(t.contiguous_data) @implements(aten.pow.Tensor_Scalar) -def pow_Tensor_Scalar(t: Tensor, exponent: float) -> Tensor: +def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> Tensor: assert isinstance(t, DiagonalSparseTensor) if exponent <= 0: @@ -219,7 +219,7 @@ def pow_Tensor_Scalar(t: Tensor, exponent: float) -> Tensor: # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. @implements(aten.pow_.Scalar) -def pow__Scalar(t: Tensor, exponent: float) -> Tensor: +def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) if exponent <= 0: @@ -232,7 +232,7 @@ def pow__Scalar(t: Tensor, exponent: float) -> Tensor: @implements(aten.unsqueeze.default) -def unsqueeze_default(t: Tensor, dim: int) -> Tensor: +def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> Tensor: assert isinstance(t, DiagonalSparseTensor) assert -t.ndim - 1 <= dim < t.ndim + 1 @@ -247,7 +247,7 @@ def unsqueeze_default(t: Tensor, dim: int) -> Tensor: @implements(aten.view.default) -def view_default(t: Tensor, shape: list[int]) -> Tensor: +def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: assert isinstance(t, DiagonalSparseTensor) if shape == list(t.shape): @@ -257,7 +257,7 @@ def view_default(t: Tensor, shape: list[int]) -> Tensor: @implements(aten.expand.default) -def expand_default(t: Tensor, sizes: list[int]) -> Tensor: +def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> Tensor: # note that sizes could also be just an int, or a torch.Size i think assert isinstance(t, DiagonalSparseTensor) assert isinstance(sizes, list) @@ -284,7 +284,7 @@ def expand_default(t: Tensor, sizes: list[int]) -> Tensor: @implements(aten.div.Scalar) -def div_Scalar(t: Tensor, divisor: float) -> Tensor: +def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> Tensor: assert isinstance(t, DiagonalSparseTensor) new_contiguous_data = aten.div.Scalar(t.contiguous_data, divisor) @@ -292,7 +292,9 @@ def div_Scalar(t: Tensor, divisor: float) -> Tensor: @implements(aten.slice.Tensor) -def slice_Tensor(t: Tensor, dim: int, start: int | None, end: int | None, step: int = 1) -> Tensor: +def slice_Tensor( + t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 +) -> Tensor: assert isinstance(t, DiagonalSparseTensor) physical_dim = t.v_to_p[dim] @@ -303,7 +305,7 @@ def slice_Tensor(t: Tensor, dim: int, start: int | None, end: int | None, step: @implements(aten.mul.Tensor) -def mul_Tensor(t1: Tensor, t2: Tensor) -> Tensor: +def mul_Tensor(t1: Tensor, t2: DiagonalSparseTensor) -> Tensor: # Element-wise multiplication where t1 is dense and t2 is DST assert isinstance(t2, DiagonalSparseTensor) @@ -312,7 +314,7 @@ def mul_Tensor(t1: Tensor, t2: Tensor) -> Tensor: @implements(aten.transpose.int) -def transpose_int(t: Tensor, dim0: int, dim1: int) -> Tensor: +def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> Tensor: assert isinstance(t, 
DiagonalSparseTensor) new_v_to_p = [p for p in t.v_to_p] From 5c859b51f28680497dd158ca5c7cb026ee024a6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 16:51:15 +0100 Subject: [PATCH 042/182] Add debug_info --- src/torchjd/autogram/diagonal_sparse_tensor.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 7ca2e498c..6f7adfdf6 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -91,6 +91,16 @@ def __repr__(self): f"{self.shape})" ) + def debug_info(self) -> str: + info = ( + f"shape: {self.shape}\n" + f"stride(): {self.stride()}\n" + f"v_to_p: {self.v_to_p}\n" + f"contiguous_data.shape: {self.contiguous_data.shape}\n" + f"contiguous_data.stride(): {self.contiguous_data.stride()}" + ) + return info + def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: if not all(0 <= i < data.ndim for i in v_to_p): From 960de6024319bbfb54e5ab73233e62b784eb52b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 16:51:26 +0100 Subject: [PATCH 043/182] Improve error message in view_default --- src/torchjd/autogram/diagonal_sparse_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 6f7adfdf6..339f42385 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -263,7 +263,9 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: if shape == list(t.shape): return diagonal_sparse_tensor(t.contiguous_data, t.v_to_p) else: - raise ValueError("Non-trivial view not supported yet.") + raise ValueError( + f"Non-trivial view not supported yet.\n{t.debug_info()}\ntarget shape: {shape}" + ) @implements(aten.expand.default) From e12080ee4e55fa02f7762be6ad005c044c9a89a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 17:27:55 +0100 Subject: [PATCH 044/182] Add multi-dimensional v_to_p --- src/torchjd/autogram/_engine.py | 3 +- .../autogram/diagonal_sparse_tensor.py | 68 +++++++++++++++---- .../autogram/test_diagonal_sparse_tensor.py | 14 ++-- 3 files changed, 63 insertions(+), 22 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 09bc6404f..6925b8f2e 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -174,7 +174,8 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - jac_output = DiagonalSparseTensor(torch.ones_like(output), output_dims * 2) + v_to_p = [[dim] for dim in output_dims * 2] + jac_output = DiagonalSparseTensor(torch.ones_like(output), v_to_p) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 339f42385..29df5b910 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -1,3 +1,7 @@ +import operator +from itertools import accumulate +from math import prod + import torch from torch import Tensor from torch.ops import aten # type: ignore @@ -21,7 +25,7 @@ def decorator(func): class DiagonalSparseTensor(torch.Tensor): @staticmethod - def __new__(cls, data: Tensor, v_to_p: list[int]): + def __new__(cls, 
data: Tensor, v_to_p: list[list[int]]): # At the moment, this class is not compositional, so we assert # that the tensor we're wrapping is exactly a Tensor assert type(data) is Tensor @@ -36,13 +40,27 @@ def __new__(cls, data: Tensor, v_to_p: list[int]): # (which is bad!) assert not data.requires_grad or not torch.is_grad_enabled() - shape = [data.shape[i] for i in v_to_p] + shape = [prod(data.shape[i] for i in dims) for dims in v_to_p] return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) - def __init__(self, data: Tensor, v_to_p: list[int]): + def __init__(self, data: Tensor, v_to_p: list[list[int]]): self.contiguous_data = data # self.data cannot be used here. self.v_to_p = v_to_p + # This is a list of strides whose shape matches that of v_to_p except that each element + # is the stride factor of the index to get the right element for the corresponding virtual + # dimension. Stride is the jump necessary to go from one element to the next one in the + # specified dimension. For instance if the i'th element of v_to_p is [0, 1, 2], then the + # i'th element of _strides is [data.shape[1] * data.shape[2], data.shape[2], 1] and so, if + # we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2], which is + # a unique decomposition, then this corresponds to indexing dimensions v_to_p[i] at indices + # [j_0, j_1, j_2] + s = data.shape + self._strides = [ + list(accumulate([1] + [s[dim] for dim in dims[:0:-1]], operator.mul))[::-1] + for dims in v_to_p + ] + def to_dense( self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None ) -> Tensor: @@ -55,12 +73,20 @@ def to_dense( torch.arange(s, device=self.contiguous_data.device) for s in self.contiguous_data.shape ] p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") - v_indices_grid = tuple(p_indices_grid[i] for i in self.v_to_p) + v_indices_grid = list[Tensor]() + for stride, dims in zip(self._strides, self.v_to_p): + stride_ = torch.tensor(stride, device=self.contiguous_data.device, dtype=torch.int) + v_indices_grid.append( + torch.sum(torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1) + ) + # This is supposed to be a vector of shape d_1 * d_2 ... + # whose elements are the coordinates 1 in p_indices_grad[d_1] times stride 1 + # plus coordinates 2 in p_indices_grad[d_2] times stride 2, etc... res = torch.zeros( self.shape, device=self.contiguous_data.device, dtype=self.contiguous_data.dtype ) - res[v_indices_grid] = self.contiguous_data + res[tuple(v_indices_grid)] = self.contiguous_data return res @classmethod @@ -102,13 +128,16 @@ def debug_info(self) -> str: return info -def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor: - if not all(0 <= i < data.ndim for i in v_to_p): +def diagonal_sparse_tensor(data: Tensor, v_to_p: list[list[int]]): + if not all(len(dims) > 0 for dims in v_to_p): + raise ValueError(f"All elements of v_to_p must be non-empty lists. Found {v_to_p}.") + if not all(all(0 <= dim < data.ndim for dim in dims) for dims in v_to_p): raise ValueError(f"Elements in v_to_p map to dimensions in data. Found {v_to_p}.") - if len(set(v_to_p)) != data.ndim: + if len(set.union(*[set(dims) for dims in v_to_p])) != data.ndim: raise ValueError("Every dimension in data must appear at least once in v_to_p.") if len(v_to_p) == data.ndim: - return torch.movedim(data, (list(range(data.ndim))), v_to_p) + # v_to_p should only contain lists of 1 dimension. 
+ return torch.movedim(data, (list(range(data.ndim))), [dims[0] for dims in v_to_p]) else: return DiagonalSparseTensor(data, v_to_p) @@ -251,7 +280,7 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> Tensor: new_data = aten.unsqueeze.default(t.contiguous_data, -1) new_v_to_p = [p for p in t.v_to_p] # Deepcopy the list to not modify the original v_to_p - new_v_to_p.insert(dim, new_data.ndim - 1) + new_v_to_p.insert(dim, [new_data.ndim - 1]) return diagonal_sparse_tensor(new_data, new_v_to_p) @@ -279,9 +308,15 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> Tensor: for dim, (original_size, new_size) in enumerate(zip(t.shape, sizes)): if new_size != original_size: - assert original_size == 1 + if original_size != 1: + raise ValueError("Cannot yet expand dim whose size != 1.") - physical_dim = t.v_to_p[dim] + if len(t.v_to_p[dim]) != 1: + raise ValueError( + "Cannot yet expand virtual dim corresponding to several physical dims" + ) + + physical_dim = t.v_to_p[dim][0] # Verify that we don't have two virtual dims expanding the same physical dim differently previous_value = new_contiguous_data_shape[physical_dim] @@ -309,7 +344,12 @@ def slice_Tensor( ) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - physical_dim = t.v_to_p[dim] + physical_dims = t.v_to_p[dim] + + if len(physical_dims) != 1: + raise ValueError("Cannot yet slice virtual dim corresponding to several physical dims.") + + physical_dim = physical_dims[0] new_contiguous_data = aten.slice.Tensor(t.contiguous_data, physical_dim, start, end, step) @@ -329,7 +369,7 @@ def mul_Tensor(t1: Tensor, t2: DiagonalSparseTensor) -> Tensor: def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - new_v_to_p = [p for p in t.v_to_p] + new_v_to_p = [dims for dims in t.v_to_p] new_v_to_p[dim0] = t.v_to_p[dim1] new_v_to_p[dim1] = t.v_to_p[dim0] diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index fe1979032..2c64dc649 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -14,7 +14,7 @@ def test_to_dense(): n = 2 m = 3 a = randn_([n, m]) - b = DiagonalSparseTensor(a, [0, 1, 1, 0]) + b = DiagonalSparseTensor(a, [[0], [1], [1], [0]]) c = b.to_dense() for i in range(n): @@ -37,7 +37,7 @@ def test_to_dense(): ) def test_diagonal_sparse_tensor_scalar(shape: list[int]): a = randn_(shape) - b = DiagonalSparseTensor(a, list(range(len(shape)))) + b = DiagonalSparseTensor(a, [[dim] for dim in range(len(shape))]) assert_close(a, b.to_dense()) @@ -45,7 +45,7 @@ def test_diagonal_sparse_tensor_scalar(shape: list[int]): @mark.parametrize("dim", [1, 2, 3, 4, 5, 10]) def test_diag_equivalence(dim: int): a = randn_([dim]) - b = DiagonalSparseTensor(a, [0, 0]) + b = DiagonalSparseTensor(a, [[0], [0]]) diag_a = torch.diag(a) @@ -55,7 +55,7 @@ def test_diag_equivalence(dim: int): def test_three_virtual_single_physical(): dim = 10 a = randn_([dim]) - b = DiagonalSparseTensor(a, [0, 0, 0]) + b = DiagonalSparseTensor(a, [[0], [0], [0]]) expected = zeros_([dim, dim, dim]) for i in range(dim): @@ -68,7 +68,7 @@ def test_three_virtual_single_physical(): def test_pointwise(func): dim = 10 a = randn_([dim]) - b = DiagonalSparseTensor(a, [0, 0]) + b = DiagonalSparseTensor(a, [[0], [0]]) c = b.to_dense() res = func(b) assert isinstance(res, DiagonalSparseTensor) @@ -80,7 +80,7 @@ def test_pointwise(func): def test_inplace_pointwise(func): dim = 
10 a = randn_([dim]) - b = DiagonalSparseTensor(a, [0, 0]) + b = DiagonalSparseTensor(a, [[0], [0]]) c = b.to_dense() func(b) assert isinstance(b, DiagonalSparseTensor) @@ -92,7 +92,7 @@ def test_inplace_pointwise(func): def test_unary(func): dim = 10 a = randn_([dim]) - b = DiagonalSparseTensor(a, [0, 0]) + b = DiagonalSparseTensor(a, [[0], [0]]) c = b.to_dense() res = func(b) From a38f907b1e58a89bd20779c0dbfcd3eb06f2f5dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 26 Oct 2025 18:56:30 +0100 Subject: [PATCH 045/182] Rename mean and sum to mean_default and sum_default * This avoids shadowing sum --- src/torchjd/autogram/diagonal_sparse_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 29df5b910..dfb06a731 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -233,13 +233,13 @@ def func_(t: Tensor) -> Tensor: @implements(aten.mean.default) -def mean(t: DiagonalSparseTensor) -> Tensor: +def mean_default(t: DiagonalSparseTensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) return aten.sum.default(t.contiguous_data) / t.numel() @implements(aten.sum.default) -def sum(t: DiagonalSparseTensor) -> Tensor: +def sum_default(t: DiagonalSparseTensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) return aten.sum.default(t.contiguous_data) From b7a3f53cecfdc0e4a872ce51a2a1ca83b986be39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 27 Oct 2025 04:14:57 +0100 Subject: [PATCH 046/182] Fix condition and comment in diagonal_sparse_tensor --- src/torchjd/autogram/diagonal_sparse_tensor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index dfb06a731..5cbf770fa 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -135,8 +135,10 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[list[int]]): raise ValueError(f"Elements in v_to_p map to dimensions in data. Found {v_to_p}.") if len(set.union(*[set(dims) for dims in v_to_p])) != data.ndim: raise ValueError("Every dimension in data must appear at least once in v_to_p.") - if len(v_to_p) == data.ndim: - # v_to_p should only contain lists of 1 dimension. + if sum([len(dims) for dims in v_to_p]) == data.ndim: + # In this case, all lists in v_to_p should contain exactly 1 element. + # Also, each physical dimension appears exactly once in the virtual tensor, so it is + # actually dense and can be returned as a dense tensor. 
return torch.movedim(data, (list(range(data.ndim))), [dims[0] for dims in v_to_p]) else: return DiagonalSparseTensor(data, v_to_p) From aa3d8a02e31c3a2d72a88a3fa16bba94d2f87658 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 27 Oct 2025 05:38:56 +0100 Subject: [PATCH 047/182] Improve implementation of view * Still not working for cases where the contiguous_data tensor would have to be reshaped too --- .../autogram/diagonal_sparse_tensor.py | 43 ++++++++++++++++--- .../autogram/test_diagonal_sparse_tensor.py | 29 +++++++++++++ 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 5cbf770fa..2650beab5 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -289,14 +289,45 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> Tensor: @implements(aten.view.default) def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: + # TODO: add error message when error is raised + # TODO: handle case where the contiguous_data has to be reshaped + assert isinstance(t, DiagonalSparseTensor) - if shape == list(t.shape): - return diagonal_sparse_tensor(t.contiguous_data, t.v_to_p) - else: - raise ValueError( - f"Non-trivial view not supported yet.\n{t.debug_info()}\ntarget shape: {shape}" - ) + if prod(shape) != t.numel(): + raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") + + new_v_to_p = [] + idx = 0 + flat_v_to_p = [dim for dims in t.v_to_p for dim in dims] + for s in shape: + # Always add the first element of the group, before even entering the while. + # This is because both s and t.contiguous_data.shape[flat_v_to_p[idx]] could be equal to 1, + # in which case the while will not even be entered but we still want to add the dimension to + # the group. More generally, it's a bit arbitrary in which groups the dimension of length 1 + # are put, but it should rarely be an issue. + + group = [flat_v_to_p[idx]] + current_product = t.contiguous_data.shape[flat_v_to_p[idx]] + idx += 1 + + while current_product < s: + if idx >= len(flat_v_to_p): + raise ValueError() + + group.append(flat_v_to_p[idx]) + current_product *= t.contiguous_data.shape[flat_v_to_p[idx]] + idx += 1 + + if current_product > s: + raise ValueError() + + new_v_to_p.append(group) + + if idx != len(flat_v_to_p): + raise ValueError(f"idx != len(flat_v_to_p). 
{idx}; {flat_v_to_p}; {shape}; {t.v_to_p}") + + return diagonal_sparse_tensor(t.contiguous_data, new_v_to_p) @implements(aten.expand.default) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 2c64dc649..98e0ded2f 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -1,5 +1,6 @@ import torch from pytest import mark +from torch.ops import aten # type: ignore from torch.testing import assert_close from utils.tensors import randn_, zeros_ @@ -97,3 +98,31 @@ def test_unary(func): res = func(b) assert_close(res.to_dense(), func(c)) + + +@mark.parametrize( + ["data_shape", "v_to_p", "target_shape"], + [ + ([2, 3], [[0], [0], [1]], [2, 2, 3]), # no change of shape + ([2, 3], [[0], [0, 1]], [2, 6]), # no change of shape + ([2, 3], [[0], [0], [1]], [2, 6]), # squashing 2 dimensions + ([2, 3], [[0], [0, 1]], [2, 2, 3]), # unsquashing into 2 dimensions + ([2, 3], [[0, 0, 1]], [2, 6]), # unsquashing into 2 dimensions + ([2, 3], [[0], [0], [1]], [12]), # squashing 3 dimensions + ([2, 3], [[0, 0, 1]], [2, 2, 3]), # unsquashing into 3 dimensions + ( + [4], + [[0], [0]], + [2, 2, 4], + ), # unsquashing into 2 dimensions, need to split physical dimension + ], +) +def test_view(data_shape: list[int], v_to_p: list[list[int]], target_shape: list[int]): + a = randn_(tuple(data_shape)) + t = DiagonalSparseTensor(a, v_to_p) + + result = aten.view.default(t, target_shape) + expected = t.to_dense().reshape(target_shape) + + assert isinstance(result, DiagonalSparseTensor) + assert torch.all(torch.eq(result.to_dense(), expected)) From c49411fd49b8a1cb6d240ff16a1b93cc2bde82ae Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 27 Oct 2025 09:29:44 +0100 Subject: [PATCH 048/182] Implement `einsum` for `v_to_p: list[int]` --- .../autogram/diagonal_sparse_tensor.py | 97 ++++++++++++++++++- .../autogram/test_diagonal_sparse_tensor.py | 11 +++ 2 files changed, 104 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 2650beab5..b03783af1 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -89,6 +89,15 @@ def to_dense( res[tuple(v_indices_grid)] = self.contiguous_data return res + def physical_to_virtual(self) -> dict[int, list[int]]: + res = dict[int, list[int]]() + for i, j in enumerate(self.v_to_p): + if j not in res: + res[j] = [i] + else: + res[j].append(i) + return res + @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs @@ -144,6 +153,13 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[list[int]]): return DiagonalSparseTensor(data, v_to_p) +def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: + if isinstance(t, DiagonalSparseTensor): + return t + else: + return DiagonalSparseTensor(t, list(range(t.ndim))) + + # pointwise functions applied to one Tensor with `0.0 → 0` _POINTWISE_FUNCTIONS = [ aten.abs.default, @@ -250,8 +266,8 @@ def sum_default(t: DiagonalSparseTensor) -> Tensor: def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - if exponent <= 0: - # Need to densify because we don't have pow(0, exponent) = 0 + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 return aten.pow.Tensor_Scalar(t.to_dense(), exponent) 
new_contiguous_data = aten.pow.Tensor_Scalar(t.contiguous_data, exponent) @@ -263,8 +279,8 @@ def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> Tensor: def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) - if exponent <= 0: - # Need to densify because we don't have pow(0, exponent) = 0 + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 # Note sure if it's even possible to densify in-place, so let's just raise an error. raise ValueError(f"in-place pow with an exponent of {exponent} (<= 0) is not supported.") @@ -407,3 +423,76 @@ def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> Tensor: new_v_to_p[dim1] = t.v_to_p[dim0] return diagonal_sparse_tensor(t.contiguous_data, new_v_to_p) + + +def einsum(*args: tuple[Tensor, list[int]], output: list[int]) -> Tensor: + # TODO: Handle ellipsis + new_indices = list[list[int]]() + tensors = list[Tensor]() + + index_parents = dict[int, int]() + + def get_representative(index: int) -> int: + if index not in index_parents: + # If an index is not yet in a cluster, put it in its own. + index_parents[index] = index + current = index_parents[index] + if current != index: + # Compress path to representative + index_parents[index] = get_representative(current) + return index_parents[index] + + def group_indices(indices: list[int]) -> None: + first_representative = get_representative(indices[0]) + for i in indices[1:]: + curr_representative = get_representative(i) + index_parents[curr_representative] = first_representative + + for t, indices in args: + if isinstance(t, DiagonalSparseTensor): + tensors.append(t.contiguous_data) + p_to_v = t.physical_to_virtual() + for indices_ in p_to_v.values(): + # elements in indices[indices_] map to the same dimension, they should be clustered + # together + group_indices([indices[i] for i in indices_]) + # record the physical dimensions, index[v] for v in vs will end-up mapping to the same + # final dimension as they were just clustered, so we can take the first, which exists as + # t is a valid DST. 
+ new_indices.append([indices[p_to_v[i][0]] for i in range(t.contiguous_data.ndim)]) + else: + tensors.append(t) + new_indices.append(indices) + + new_indices = [[get_representative(i) for i in indices] for indices in new_indices] + new_output = list[int]() + v_to_p = list[int]() + for i in output: + new_i = get_representative(i) + if new_i in new_output: + v_to_p.append(new_output.index(new_i)) + else: + v_to_p.append(len(new_output)) + new_output.append(new_i) + + alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + equation = ( + ",".join("".join(alphabet[i] for i in indices) for indices in new_indices) + + "->" + + "".join([alphabet[i] for i in new_output]) + ) + + data = torch.einsum(equation, *tensors) + return diagonal_sparse_tensor(data, v_to_p) + + +@implements(aten.bmm.default) +def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert ( + mat1.ndim == 3 + and mat2.ndim == 3 + and mat1.shape[0] == mat2.shape[0] + and mat1.shape[2] == mat2.shape[1] + ) + + return einsum((mat1, [0, 1, 2]), (mat2, [0, 2, 3]), output=[0, 1, 3]) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 98e0ded2f..8e324d6dc 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -8,6 +8,7 @@ _IN_PLACE_POINTWISE_FUNCTIONS, _POINTWISE_FUNCTIONS, DiagonalSparseTensor, + einsum, ) @@ -23,6 +24,16 @@ def test_to_dense(): assert c[i, j, j, i] == a[i, j] +def test_einsum(): + a = DiagonalSparseTensor(torch.randn([4, 5]), [0, 0, 1]) + b = DiagonalSparseTensor(torch.randn([5, 4]), [1, 0, 0]) + + res = einsum((a, [0, 1, 2]), (b, [0, 2, 3]), output=[0, 1, 3]) + + expected = torch.einsum("ijk,ikl->ijl", a.to_dense(), b.to_dense()) + assert_close(res, expected) + + @mark.parametrize( "shape", [ From 6c0e9ec0071ee8eb2ab86853e380affeb017963b Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 27 Oct 2025 10:52:05 +0100 Subject: [PATCH 049/182] Factor access to physical shape in view. --- src/torchjd/autogram/diagonal_sparse_tensor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index b03783af1..0f21eab91 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -316,6 +316,7 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: new_v_to_p = [] idx = 0 flat_v_to_p = [dim for dims in t.v_to_p for dim in dims] + p_shape = t.contiguous_data.shape for s in shape: # Always add the first element of the group, before even entering the while. # This is because both s and t.contiguous_data.shape[flat_v_to_p[idx]] could be equal to 1, @@ -324,7 +325,7 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: # are put, but it should rarely be an issue. 
group = [flat_v_to_p[idx]] - current_product = t.contiguous_data.shape[flat_v_to_p[idx]] + current_product = p_shape[flat_v_to_p[idx]] idx += 1 while current_product < s: @@ -332,7 +333,7 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: raise ValueError() group.append(flat_v_to_p[idx]) - current_product *= t.contiguous_data.shape[flat_v_to_p[idx]] + current_product *= p_shape[flat_v_to_p[idx]] idx += 1 if current_product > s: From 2f774dbe5e4d50c6336ba0742c349f161a42e5bf Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 27 Oct 2025 11:33:59 +0100 Subject: [PATCH 050/182] rename `physical_to_virtual` to `p_to_vs`. --- src/torchjd/autogram/diagonal_sparse_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 0f21eab91..d7dfe4938 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -89,7 +89,7 @@ def to_dense( res[tuple(v_indices_grid)] = self.contiguous_data return res - def physical_to_virtual(self) -> dict[int, list[int]]: + def p_to_vs(self) -> dict[int, list[int]]: res = dict[int, list[int]]() for i, j in enumerate(self.v_to_p): if j not in res: @@ -452,7 +452,7 @@ def group_indices(indices: list[int]) -> None: for t, indices in args: if isinstance(t, DiagonalSparseTensor): tensors.append(t.contiguous_data) - p_to_v = t.physical_to_virtual() + p_to_v = t.p_to_vs() for indices_ in p_to_v.values(): # elements in indices[indices_] map to the same dimension, they should be clustered # together From 73441188276f6bcb07c845a15d4b8e40169407a5 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 27 Oct 2025 11:53:42 +0100 Subject: [PATCH 051/182] Refactor p_to_vs. * Now returns list * Improve naming. --- .../autogram/diagonal_sparse_tensor.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index d7dfe4938..9e8289da8 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -89,14 +89,14 @@ def to_dense( res[tuple(v_indices_grid)] = self.contiguous_data return res - def p_to_vs(self) -> dict[int, list[int]]: + def p_to_vs(self) -> list[list[int]]: res = dict[int, list[int]]() - for i, j in enumerate(self.v_to_p): - if j not in res: - res[j] = [i] + for v_dim, p_dims in enumerate(self.v_to_p): + if p_dims not in res: + res[p_dims] = [v_dim] else: - res[j].append(i) - return res + res[p_dims].append(v_dim) + return [res[i] for i in range(len(res))] @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): @@ -157,7 +157,7 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: if isinstance(t, DiagonalSparseTensor): return t else: - return DiagonalSparseTensor(t, list(range(t.ndim))) + return DiagonalSparseTensor(t, [[i] for i in range(t.ndim)]) # pointwise functions applied to one Tensor with `0.0 → 0` @@ -431,6 +431,9 @@ def einsum(*args: tuple[Tensor, list[int]], output: list[int]) -> Tensor: new_indices = list[list[int]]() tensors = list[Tensor]() + # If we have an index v for some virtual dim whose corresponding v_to_p is a non-trivial list + # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. 
+ index_parents = dict[int, int]() def get_representative(index: int) -> int: @@ -452,15 +455,15 @@ def group_indices(indices: list[int]) -> None: for t, indices in args: if isinstance(t, DiagonalSparseTensor): tensors.append(t.contiguous_data) - p_to_v = t.p_to_vs() - for indices_ in p_to_v.values(): + p_to_vs = t.p_to_vs() + for indices_ in p_to_vs: # elements in indices[indices_] map to the same dimension, they should be clustered # together group_indices([indices[i] for i in indices_]) # record the physical dimensions, index[v] for v in vs will end-up mapping to the same # final dimension as they were just clustered, so we can take the first, which exists as # t is a valid DST. - new_indices.append([indices[p_to_v[i][0]] for i in range(t.contiguous_data.ndim)]) + new_indices.append([indices[vs[0]] for vs in p_to_vs]) else: tensors.append(t) new_indices.append(indices) From 0884495c310c2041afceafa85e57ae9459f1b0a1 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 27 Oct 2025 13:31:51 +0100 Subject: [PATCH 052/182] Adapt `einsum` to `v_to_p: list[list[int]]` --- .../autogram/diagonal_sparse_tensor.py | 106 ++++++++++++------ .../autogram/test_diagonal_sparse_tensor.py | 4 +- 2 files changed, 73 insertions(+), 37 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 9e8289da8..c3682f163 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -89,13 +89,18 @@ def to_dense( res[tuple(v_indices_grid)] = self.contiguous_data return res - def p_to_vs(self) -> list[list[int]]: - res = dict[int, list[int]]() + def p_to_vs(self) -> list[list[tuple[int, int]]]: + """ + A physical dimension is mapped to a list of couples of the form + (virtual_dim, sub_index_in_virtual_dim) + """ + res = dict[int, list[tuple[int, int]]]() for v_dim, p_dims in enumerate(self.v_to_p): - if p_dims not in res: - res[p_dims] = [v_dim] - else: - res[p_dims].append(v_dim) + for i, p_dim in enumerate(p_dims): + if p_dim not in res: + res[p_dim] = [(v_dim, i)] + else: + res[p_dim].append((v_dim, i)) return [res[i] for i in range(len(res))] @classmethod @@ -428,15 +433,26 @@ def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> Tensor: def einsum(*args: tuple[Tensor, list[int]], output: list[int]) -> Tensor: # TODO: Handle ellipsis - new_indices = list[list[int]]() - tensors = list[Tensor]() + # TODO: Should we take only DiagonalSparseTensors and leave the responsability to cast to the + # caller? - # If we have an index v for some virtual dim whose corresponding v_to_p is a non-trivial list + # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. - - index_parents = dict[int, int]() - - def get_representative(index: int) -> int: + # For this reason, an index is decomposed into sub-indices that are then independently + # clustered. + # So if an index i in args for some DiagonalSparseTensor corresponds to a v_to_ps [j, k, l], then + # We will consider three indices (i, 0), (i, 1) and (i, 2). + # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then + # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in + # the resulting einsum). 
+ # Note that this is a problem if two virtual dimensions (from possibly different + # DiagonaSparseTensors) have the same size but not the same decomposition into physical + # dimension sizes. For now lets leave the responsibility to care about that in the calling + # functions, if we can factor code later on we will. + + index_parents = dict[tuple[int, int], tuple[int, int]]() + + def get_representative(index: tuple[int, int]) -> tuple[int, int]: if index not in index_parents: # If an index is not yet in a cluster, put it in its own. index_parents[index] = index @@ -446,48 +462,65 @@ def get_representative(index: int) -> int: index_parents[index] = get_representative(current) return index_parents[index] - def group_indices(indices: list[int]) -> None: + def group_indices(indices: list[tuple[int, int]]) -> None: first_representative = get_representative(indices[0]) for i in indices[1:]: curr_representative = get_representative(i) index_parents[curr_representative] = first_representative + new_indices_pair = list[list[tuple[int, int]]]() + tensors = list[Tensor]() + indices_to_n_pdims = dict[int, int]() for t, indices in args: if isinstance(t, DiagonalSparseTensor): tensors.append(t.contiguous_data) + for ps, index in zip(t.v_to_p, indices): + if index in indices_to_n_pdims: + assert indices_to_n_pdims[index] == len(ps) + else: + indices_to_n_pdims[index] = len(ps) p_to_vs = t.p_to_vs() for indices_ in p_to_vs: # elements in indices[indices_] map to the same dimension, they should be clustered # together - group_indices([indices[i] for i in indices_]) + group_indices([(indices[i], sub_i) for i, sub_i in indices_]) # record the physical dimensions, index[v] for v in vs will end-up mapping to the same # final dimension as they were just clustered, so we can take the first, which exists as # t is a valid DST. 
- new_indices.append([indices[vs[0]] for vs in p_to_vs]) + new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs]) else: tensors.append(t) - new_indices.append(indices) - - new_indices = [[get_representative(i) for i in indices] for indices in new_indices] + new_indices_pair.append([(i, 0) for i in indices]) + + current = 0 + pair_to_int = dict[tuple[int, int], int]() + + def unique_int(pair: tuple[int, int]) -> int: + nonlocal current + if pair in pair_to_int: + return pair_to_int[pair] + pair_to_int[pair] = current + current += 1 + return pair_to_int[pair] + + new_indices = [ + [unique_int(get_representative(i)) for i in indices] for indices in new_indices_pair + ] new_output = list[int]() - v_to_p = list[int]() + v_to_ps = list[list[int]]() for i in output: - new_i = get_representative(i) - if new_i in new_output: - v_to_p.append(new_output.index(new_i)) - else: - v_to_p.append(len(new_output)) - new_output.append(new_i) - - alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - equation = ( - ",".join("".join(alphabet[i] for i in indices) for indices in new_indices) - + "->" - + "".join([alphabet[i] for i in new_output]) - ) + current_v_to_ps = [] + for j in range(indices_to_n_pdims[i]): + k = unique_int(get_representative((i, j))) + if k in new_output: + current_v_to_ps.append(new_output.index(k)) + else: + current_v_to_ps.append(len(new_output)) + new_output.append(k) + v_to_ps.append(current_v_to_ps) - data = torch.einsum(equation, *tensors) - return diagonal_sparse_tensor(data, v_to_p) + data = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output) + return diagonal_sparse_tensor(data, v_to_ps) @implements(aten.bmm.default) @@ -499,4 +532,7 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: and mat1.shape[2] == mat2.shape[1] ) + # TODO: Verify that if mat1 and/or mat2 are DiagonalSparseTensors, then their dimension `0` have + # the same physical dimension sizes decompositions. + # If not, can reshape to common decomposition? 
return einsum((mat1, [0, 1, 2]), (mat2, [0, 2, 3]), output=[0, 1, 3]) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 8e324d6dc..ba53e10cf 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -25,8 +25,8 @@ def test_to_dense(): def test_einsum(): - a = DiagonalSparseTensor(torch.randn([4, 5]), [0, 0, 1]) - b = DiagonalSparseTensor(torch.randn([5, 4]), [1, 0, 0]) + a = DiagonalSparseTensor(torch.randn([4, 5]), [[0], [0], [1]]) + b = DiagonalSparseTensor(torch.randn([5, 4]), [[1], [0], [0]]) res = einsum((a, [0, 1, 2]), (b, [0, 2, 3]), output=[0, 1, 3]) From a5fda15e55426961a5eda9e4266e039af9670c08 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 27 Oct 2025 13:35:21 +0100 Subject: [PATCH 053/182] Rename `v_to_p` to `v_to_ps` --- src/torchjd/autogram/_engine.py | 4 +- .../autogram/diagonal_sparse_tensor.py | 106 +++++++++--------- .../autogram/test_diagonal_sparse_tensor.py | 6 +- 3 files changed, 58 insertions(+), 58 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 6925b8f2e..cca94becd 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -174,8 +174,8 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - v_to_p = [[dim] for dim in output_dims * 2] - jac_output = DiagonalSparseTensor(torch.ones_like(output), v_to_p) + v_to_ps = [[dim] for dim in output_dims * 2] + jac_output = DiagonalSparseTensor(torch.ones_like(output), v_to_ps) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index c3682f163..e98828e10 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -25,7 +25,7 @@ def decorator(func): class DiagonalSparseTensor(torch.Tensor): @staticmethod - def __new__(cls, data: Tensor, v_to_p: list[list[int]]): + def __new__(cls, data: Tensor, v_to_ps: list[list[int]]): # At the moment, this class is not compositional, so we assert # that the tensor we're wrapping is exactly a Tensor assert type(data) is Tensor @@ -40,25 +40,25 @@ def __new__(cls, data: Tensor, v_to_p: list[list[int]]): # (which is bad!) assert not data.requires_grad or not torch.is_grad_enabled() - shape = [prod(data.shape[i] for i in dims) for dims in v_to_p] + shape = [prod(data.shape[i] for i in dims) for dims in v_to_ps] return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) - def __init__(self, data: Tensor, v_to_p: list[list[int]]): + def __init__(self, data: Tensor, v_to_ps: list[list[int]]): self.contiguous_data = data # self.data cannot be used here. - self.v_to_p = v_to_p + self.v_to_ps = v_to_ps - # This is a list of strides whose shape matches that of v_to_p except that each element + # This is a list of strides whose shape matches that of v_to_ps except that each element # is the stride factor of the index to get the right element for the corresponding virtual # dimension. Stride is the jump necessary to go from one element to the next one in the - # specified dimension. For instance if the i'th element of v_to_p is [0, 1, 2], then the + # specified dimension. 
For instance if the i'th element of v_to_ps is [0, 1, 2], then the # i'th element of _strides is [data.shape[1] * data.shape[2], data.shape[2], 1] and so, if # we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2], which is - # a unique decomposition, then this corresponds to indexing dimensions v_to_p[i] at indices + # a unique decomposition, then this corresponds to indexing dimensions v_to_ps[i] at indices # [j_0, j_1, j_2] s = data.shape self._strides = [ list(accumulate([1] + [s[dim] for dim in dims[:0:-1]], operator.mul))[::-1] - for dims in v_to_p + for dims in v_to_ps ] def to_dense( @@ -74,7 +74,7 @@ def to_dense( ] p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") v_indices_grid = list[Tensor]() - for stride, dims in zip(self._strides, self.v_to_p): + for stride, dims in zip(self._strides, self.v_to_ps): stride_ = torch.tensor(stride, device=self.contiguous_data.device, dtype=torch.int) v_indices_grid.append( torch.sum(torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1) @@ -95,7 +95,7 @@ def p_to_vs(self) -> list[list[tuple[int, int]]]: (virtual_dim, sub_index_in_virtual_dim) """ res = dict[int, list[tuple[int, int]]]() - for v_dim, p_dims in enumerate(self.v_to_p): + for v_dim, p_dims in enumerate(self.v_to_ps): for i, p_dim in enumerate(p_dims): if p_dim not in res: res[p_dim] = [(v_dim, i)] @@ -127,7 +127,7 @@ def unwrap_to_dense(t: Tensor): def __repr__(self): return ( - f"DiagonalSparseTensor(data={self.contiguous_data}, v_to_p_map={self.v_to_p}, shape=" + f"DiagonalSparseTensor(data={self.contiguous_data}, v_to_ps_map={self.v_to_ps}, shape=" f"{self.shape})" ) @@ -135,27 +135,27 @@ def debug_info(self) -> str: info = ( f"shape: {self.shape}\n" f"stride(): {self.stride()}\n" - f"v_to_p: {self.v_to_p}\n" + f"v_to_ps: {self.v_to_ps}\n" f"contiguous_data.shape: {self.contiguous_data.shape}\n" f"contiguous_data.stride(): {self.contiguous_data.stride()}" ) return info -def diagonal_sparse_tensor(data: Tensor, v_to_p: list[list[int]]): - if not all(len(dims) > 0 for dims in v_to_p): - raise ValueError(f"All elements of v_to_p must be non-empty lists. Found {v_to_p}.") - if not all(all(0 <= dim < data.ndim for dim in dims) for dims in v_to_p): - raise ValueError(f"Elements in v_to_p map to dimensions in data. Found {v_to_p}.") - if len(set.union(*[set(dims) for dims in v_to_p])) != data.ndim: - raise ValueError("Every dimension in data must appear at least once in v_to_p.") - if sum([len(dims) for dims in v_to_p]) == data.ndim: - # In this case, all lists in v_to_p should contain exactly 1 element. +def diagonal_sparse_tensor(data: Tensor, v_to_ps: list[list[int]]): + if not all(len(dims) > 0 for dims in v_to_ps): + raise ValueError(f"All elements of v_to_ps must be non-empty lists. Found {v_to_ps}.") + if not all(all(0 <= dim < data.ndim for dim in dims) for dims in v_to_ps): + raise ValueError(f"Elements in v_to_ps map to dimensions in data. Found {v_to_ps}.") + if len(set.union(*[set(dims) for dims in v_to_ps])) != data.ndim: + raise ValueError("Every dimension in data must appear at least once in v_to_ps.") + if sum([len(dims) for dims in v_to_ps]) == data.ndim: + # In this case, all lists in v_to_ps should contain exactly 1 element. # Also, each physical dimension appears exactly once in the virtual tensor, so it is # actually dense and can be returned as a dense tensor. 
- return torch.movedim(data, (list(range(data.ndim))), [dims[0] for dims in v_to_p]) + return torch.movedim(data, (list(range(data.ndim))), [dims[0] for dims in v_to_ps]) else: - return DiagonalSparseTensor(data, v_to_p) + return DiagonalSparseTensor(data, v_to_ps) def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: @@ -235,7 +235,7 @@ def _override_pointwise(op): @implements(op) def func_(t: Tensor): assert isinstance(t, DiagonalSparseTensor) - return diagonal_sparse_tensor(op(t.contiguous_data), t.v_to_p) + return diagonal_sparse_tensor(op(t.contiguous_data), t.v_to_ps) return func_ @@ -276,7 +276,7 @@ def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> Tensor: return aten.pow.Tensor_Scalar(t.to_dense(), exponent) new_contiguous_data = aten.pow.Tensor_Scalar(t.contiguous_data, exponent) - return diagonal_sparse_tensor(new_contiguous_data, t.v_to_p) + return diagonal_sparse_tensor(new_contiguous_data, t.v_to_ps) # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. @@ -302,10 +302,10 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> Tensor: dim = t.ndim + dim + 1 new_data = aten.unsqueeze.default(t.contiguous_data, -1) - new_v_to_p = [p for p in t.v_to_p] # Deepcopy the list to not modify the original v_to_p - new_v_to_p.insert(dim, [new_data.ndim - 1]) + new_v_to_ps = [p for p in t.v_to_ps] # Deepcopy the list to not modify the original v_to_ps + new_v_to_ps.insert(dim, [new_data.ndim - 1]) - return diagonal_sparse_tensor(new_data, new_v_to_p) + return diagonal_sparse_tensor(new_data, new_v_to_ps) @implements(aten.view.default) @@ -318,38 +318,38 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: if prod(shape) != t.numel(): raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") - new_v_to_p = [] + new_v_to_ps = [] idx = 0 - flat_v_to_p = [dim for dims in t.v_to_p for dim in dims] + flat_v_to_ps = [dim for dims in t.v_to_ps for dim in dims] p_shape = t.contiguous_data.shape for s in shape: # Always add the first element of the group, before even entering the while. - # This is because both s and t.contiguous_data.shape[flat_v_to_p[idx]] could be equal to 1, + # This is because both s and t.contiguous_data.shape[flat_v_to_ps[idx]] could be equal to 1, # in which case the while will not even be entered but we still want to add the dimension to # the group. More generally, it's a bit arbitrary in which groups the dimension of length 1 # are put, but it should rarely be an issue. - group = [flat_v_to_p[idx]] - current_product = p_shape[flat_v_to_p[idx]] + group = [flat_v_to_ps[idx]] + current_product = p_shape[flat_v_to_ps[idx]] idx += 1 while current_product < s: - if idx >= len(flat_v_to_p): + if idx >= len(flat_v_to_ps): raise ValueError() - group.append(flat_v_to_p[idx]) - current_product *= p_shape[flat_v_to_p[idx]] + group.append(flat_v_to_ps[idx]) + current_product *= p_shape[flat_v_to_ps[idx]] idx += 1 if current_product > s: raise ValueError() - new_v_to_p.append(group) + new_v_to_ps.append(group) - if idx != len(flat_v_to_p): - raise ValueError(f"idx != len(flat_v_to_p). {idx}; {flat_v_to_p}; {shape}; {t.v_to_p}") + if idx != len(flat_v_to_ps): + raise ValueError(f"idx != len(flat_v_to_ps). 
{idx}; {flat_v_to_ps}; {shape}; {t.v_to_ps}") - return diagonal_sparse_tensor(t.contiguous_data, new_v_to_p) + return diagonal_sparse_tensor(t.contiguous_data, new_v_to_ps) @implements(aten.expand.default) @@ -366,12 +366,12 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> Tensor: if original_size != 1: raise ValueError("Cannot yet expand dim whose size != 1.") - if len(t.v_to_p[dim]) != 1: + if len(t.v_to_ps[dim]) != 1: raise ValueError( "Cannot yet expand virtual dim corresponding to several physical dims" ) - physical_dim = t.v_to_p[dim][0] + physical_dim = t.v_to_ps[dim][0] # Verify that we don't have two virtual dims expanding the same physical dim differently previous_value = new_contiguous_data_shape[physical_dim] @@ -381,8 +381,8 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> Tensor: new_contiguous_data = aten.expand.default(t.contiguous_data, new_contiguous_data_shape) - # Not sure if it's safe to just provide v_to_p as-is. I think we're supposed to copy it. - return diagonal_sparse_tensor(new_contiguous_data, t.v_to_p) + # Not sure if it's safe to just provide v_to_ps as-is. I think we're supposed to copy it. + return diagonal_sparse_tensor(new_contiguous_data, t.v_to_ps) @implements(aten.div.Scalar) @@ -390,7 +390,7 @@ def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> Tensor: assert isinstance(t, DiagonalSparseTensor) new_contiguous_data = aten.div.Scalar(t.contiguous_data, divisor) - return diagonal_sparse_tensor(new_contiguous_data, t.v_to_p) + return diagonal_sparse_tensor(new_contiguous_data, t.v_to_ps) @implements(aten.slice.Tensor) @@ -399,7 +399,7 @@ def slice_Tensor( ) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - physical_dims = t.v_to_p[dim] + physical_dims = t.v_to_ps[dim] if len(physical_dims) != 1: raise ValueError("Cannot yet slice virtual dim corresponding to several physical dims.") @@ -408,7 +408,7 @@ def slice_Tensor( new_contiguous_data = aten.slice.Tensor(t.contiguous_data, physical_dim, start, end, step) - return diagonal_sparse_tensor(new_contiguous_data, t.v_to_p) + return diagonal_sparse_tensor(new_contiguous_data, t.v_to_ps) @implements(aten.mul.Tensor) @@ -417,18 +417,18 @@ def mul_Tensor(t1: Tensor, t2: DiagonalSparseTensor) -> Tensor: assert isinstance(t2, DiagonalSparseTensor) new_contiguous_data = aten.mul.Tensor(t1, t2.contiguous_data) - return diagonal_sparse_tensor(new_contiguous_data, t2.v_to_p) + return diagonal_sparse_tensor(new_contiguous_data, t2.v_to_ps) @implements(aten.transpose.int) def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - new_v_to_p = [dims for dims in t.v_to_p] - new_v_to_p[dim0] = t.v_to_p[dim1] - new_v_to_p[dim1] = t.v_to_p[dim0] + new_v_to_ps = [dims for dims in t.v_to_ps] + new_v_to_ps[dim0] = t.v_to_ps[dim1] + new_v_to_ps[dim1] = t.v_to_ps[dim0] - return diagonal_sparse_tensor(t.contiguous_data, new_v_to_p) + return diagonal_sparse_tensor(t.contiguous_data, new_v_to_ps) def einsum(*args: tuple[Tensor, list[int]], output: list[int]) -> Tensor: @@ -474,7 +474,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None: for t, indices in args: if isinstance(t, DiagonalSparseTensor): tensors.append(t.contiguous_data) - for ps, index in zip(t.v_to_p, indices): + for ps, index in zip(t.v_to_ps, indices): if index in indices_to_n_pdims: assert indices_to_n_pdims[index] == len(ps) else: diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py 
b/tests/unit/autogram/test_diagonal_sparse_tensor.py index ba53e10cf..8573e85af 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -112,7 +112,7 @@ def test_unary(func): @mark.parametrize( - ["data_shape", "v_to_p", "target_shape"], + ["data_shape", "v_to_ps", "target_shape"], [ ([2, 3], [[0], [0], [1]], [2, 2, 3]), # no change of shape ([2, 3], [[0], [0, 1]], [2, 6]), # no change of shape @@ -128,9 +128,9 @@ def test_unary(func): ), # unsquashing into 2 dimensions, need to split physical dimension ], ) -def test_view(data_shape: list[int], v_to_p: list[list[int]], target_shape: list[int]): +def test_view(data_shape: list[int], v_to_ps: list[list[int]], target_shape: list[int]): a = randn_(tuple(data_shape)) - t = DiagonalSparseTensor(a, v_to_p) + t = DiagonalSparseTensor(a, v_to_ps) result = aten.view.default(t, target_shape) expected = t.to_dense().reshape(target_shape) From d03d3a4393774844572e33b6b742f21d5f47e5e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 27 Oct 2025 19:27:35 +0100 Subject: [PATCH 054/182] Remove diagonal_sparse_tensor, remove to_diagonal_sparse_tensor --- .../autogram/diagonal_sparse_tensor.py | 75 ++++++++----------- 1 file changed, 31 insertions(+), 44 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index e98828e10..385fd8aab 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -44,6 +44,15 @@ def __new__(cls, data: Tensor, v_to_ps: list[list[int]]): return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) def __init__(self, data: Tensor, v_to_ps: list[list[int]]): + if not all(len(dims) > 0 for dims in v_to_ps): + raise ValueError(f"All elements of v_to_ps must be non-empty lists. Found {v_to_ps}.") + if not all(all(0 <= dim < data.ndim for dim in dims) for dims in v_to_ps): + raise ValueError( + f"Elements in v_to_ps must map to dimensions in data. Found {v_to_ps}." + ) + if len(set.union(*[set(dims) for dims in v_to_ps])) != data.ndim: + raise ValueError("Every dimension in data must appear at least once in v_to_ps.") + self.contiguous_data = data # self.data cannot be used here. self.v_to_ps = v_to_ps @@ -142,29 +151,6 @@ def debug_info(self) -> str: return info -def diagonal_sparse_tensor(data: Tensor, v_to_ps: list[list[int]]): - if not all(len(dims) > 0 for dims in v_to_ps): - raise ValueError(f"All elements of v_to_ps must be non-empty lists. Found {v_to_ps}.") - if not all(all(0 <= dim < data.ndim for dim in dims) for dims in v_to_ps): - raise ValueError(f"Elements in v_to_ps map to dimensions in data. Found {v_to_ps}.") - if len(set.union(*[set(dims) for dims in v_to_ps])) != data.ndim: - raise ValueError("Every dimension in data must appear at least once in v_to_ps.") - if sum([len(dims) for dims in v_to_ps]) == data.ndim: - # In this case, all lists in v_to_ps should contain exactly 1 element. - # Also, each physical dimension appears exactly once in the virtual tensor, so it is - # actually dense and can be returned as a dense tensor. 
- return torch.movedim(data, (list(range(data.ndim))), [dims[0] for dims in v_to_ps]) - else: - return DiagonalSparseTensor(data, v_to_ps) - - -def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: - if isinstance(t, DiagonalSparseTensor): - return t - else: - return DiagonalSparseTensor(t, [[i] for i in range(t.ndim)]) - - # pointwise functions applied to one Tensor with `0.0 → 0` _POINTWISE_FUNCTIONS = [ aten.abs.default, @@ -233,16 +219,16 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: def _override_pointwise(op): @implements(op) - def func_(t: Tensor): + def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) - return diagonal_sparse_tensor(op(t.contiguous_data), t.v_to_ps) + return DiagonalSparseTensor(op(t.contiguous_data), t.v_to_ps) return func_ def _override_inplace_pointwise(op): @implements(op) - def func_(t: Tensor) -> Tensor: + def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) op(t.contiguous_data) return t @@ -268,7 +254,7 @@ def sum_default(t: DiagonalSparseTensor) -> Tensor: @implements(aten.pow.Tensor_Scalar) -def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> Tensor: +def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) if exponent <= 0.0: @@ -276,7 +262,7 @@ def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> Tensor: return aten.pow.Tensor_Scalar(t.to_dense(), exponent) new_contiguous_data = aten.pow.Tensor_Scalar(t.contiguous_data, exponent) - return diagonal_sparse_tensor(new_contiguous_data, t.v_to_ps) + return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. @@ -294,7 +280,7 @@ def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTenso @implements(aten.unsqueeze.default) -def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> Tensor: +def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) assert -t.ndim - 1 <= dim < t.ndim + 1 @@ -305,11 +291,11 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> Tensor: new_v_to_ps = [p for p in t.v_to_ps] # Deepcopy the list to not modify the original v_to_ps new_v_to_ps.insert(dim, [new_data.ndim - 1]) - return diagonal_sparse_tensor(new_data, new_v_to_ps) + return DiagonalSparseTensor(new_data, new_v_to_ps) @implements(aten.view.default) -def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: +def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTensor: # TODO: add error message when error is raised # TODO: handle case where the contiguous_data has to be reshaped @@ -349,11 +335,11 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: if idx != len(flat_v_to_ps): raise ValueError(f"idx != len(flat_v_to_ps). 
{idx}; {flat_v_to_ps}; {shape}; {t.v_to_ps}") - return diagonal_sparse_tensor(t.contiguous_data, new_v_to_ps) + return DiagonalSparseTensor(t.contiguous_data, new_v_to_ps) @implements(aten.expand.default) -def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> Tensor: +def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseTensor: # note that sizes could also be just an int, or a torch.Size i think assert isinstance(t, DiagonalSparseTensor) assert isinstance(sizes, list) @@ -382,21 +368,21 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> Tensor: new_contiguous_data = aten.expand.default(t.contiguous_data, new_contiguous_data_shape) # Not sure if it's safe to just provide v_to_ps as-is. I think we're supposed to copy it. - return diagonal_sparse_tensor(new_contiguous_data, t.v_to_ps) + return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) @implements(aten.div.Scalar) -def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> Tensor: +def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) new_contiguous_data = aten.div.Scalar(t.contiguous_data, divisor) - return diagonal_sparse_tensor(new_contiguous_data, t.v_to_ps) + return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) @implements(aten.slice.Tensor) def slice_Tensor( t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 -) -> Tensor: +) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) physical_dims = t.v_to_ps[dim] @@ -408,30 +394,30 @@ def slice_Tensor( new_contiguous_data = aten.slice.Tensor(t.contiguous_data, physical_dim, start, end, step) - return diagonal_sparse_tensor(new_contiguous_data, t.v_to_ps) + return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) @implements(aten.mul.Tensor) -def mul_Tensor(t1: Tensor, t2: DiagonalSparseTensor) -> Tensor: +def mul_Tensor(t1: Tensor, t2: DiagonalSparseTensor) -> DiagonalSparseTensor: # Element-wise multiplication where t1 is dense and t2 is DST assert isinstance(t2, DiagonalSparseTensor) new_contiguous_data = aten.mul.Tensor(t1, t2.contiguous_data) - return diagonal_sparse_tensor(new_contiguous_data, t2.v_to_ps) + return DiagonalSparseTensor(new_contiguous_data, t2.v_to_ps) @implements(aten.transpose.int) -def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> Tensor: +def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) new_v_to_ps = [dims for dims in t.v_to_ps] new_v_to_ps[dim0] = t.v_to_ps[dim1] new_v_to_ps[dim1] = t.v_to_ps[dim0] - return diagonal_sparse_tensor(t.contiguous_data, new_v_to_ps) + return DiagonalSparseTensor(t.contiguous_data, new_v_to_ps) -def einsum(*args: tuple[Tensor, list[int]], output: list[int]) -> Tensor: +def einsum(*args: tuple[Tensor, list[int]], output: list[int]) -> DiagonalSparseTensor: # TODO: Handle ellipsis # TODO: Should we take only DiagonalSparseTensors and leave the responsability to cast to the # caller? 
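The integer-index calling convention used by this einsum helper mirrors torch.einsum's sublist form. As a quick dense-tensor sanity check (the shapes below are arbitrary, purely illustrative):

import torch

a, b = torch.randn(4, 5, 6), torch.randn(4, 6, 7)
expected = torch.einsum("ijk,ikl->ijl", a, b)
# Sublist form: (operand, index-list) pairs followed by the output index list,
# matching the (tensor, indices) pairs and the `output` argument of the helper above.
sublist = torch.einsum(a, [0, 1, 2], b, [0, 2, 3], [0, 1, 3])
assert torch.allclose(expected, sublist)
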
@@ -520,11 +506,12 @@ def unique_int(pair: tuple[int, int]) -> int: v_to_ps.append(current_v_to_ps) data = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output) - return diagonal_sparse_tensor(data, v_to_ps) + return DiagonalSparseTensor(data, v_to_ps) @implements(aten.bmm.default) def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) assert ( mat1.ndim == 3 and mat2.ndim == 3 From 3bf75fb948d6b570cd3956af269fe7197ba23682 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 27 Oct 2025 19:29:47 +0100 Subject: [PATCH 055/182] Remove aten.div.Scalar from _POINTWISE_FUNCTIONS --- src/torchjd/autogram/diagonal_sparse_tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 385fd8aab..259fb6fb9 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -160,7 +160,6 @@ def debug_info(self) -> str: aten.atan.default, aten.atanh.default, aten.ceil.default, - aten.div.Scalar, aten.erf.default, aten.erfinv.default, aten.expm1.default, From 35e841ab6fb262a6303c604e514f50063af7623c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 27 Oct 2025 19:35:40 +0100 Subject: [PATCH 056/182] Fix set union to be able to handle empty list of sets --- src/torchjd/autogram/diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 259fb6fb9..7ceebbffa 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -50,7 +50,7 @@ def __init__(self, data: Tensor, v_to_ps: list[list[int]]): raise ValueError( f"Elements in v_to_ps must map to dimensions in data. Found {v_to_ps}." ) - if len(set.union(*[set(dims) for dims in v_to_ps])) != data.ndim: + if len(set().union(*[set(dims) for dims in v_to_ps])) != data.ndim: raise ValueError("Every dimension in data must appear at least once in v_to_ps.") self.contiguous_data = data # self.data cannot be used here. 
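For orientation, the indexing rule that v_to_ps encodes can be reproduced with plain torch, independently of the class above. Assuming physical data of shape [2, 3] and v_to_ps = [[0], [0], [1]] (the configuration used in the unit tests), the virtual tensor has shape [2, 2, 3] and entry [i, k, j] equals data[i, j] when i == k, and 0 otherwise:

import torch

data = torch.randn(2, 3)                         # physical tensor
eye = torch.eye(2)                                # ties the two virtual dims that share physical dim 0
dense = torch.einsum("ij,ik->ikj", data, eye)     # dense[i, k, j] = data[i, j] if i == k else 0.0

assert dense.shape == (2, 2, 3)
assert torch.equal(dense[0, 0], data[0])          # on-diagonal slice carries the data
assert torch.equal(dense[0, 1], torch.zeros(3))   # off-diagonal slice is zero
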
From 2d37031c733d43334d5fe1a4fbcf129ab92f0d78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 27 Oct 2025 19:42:42 +0100 Subject: [PATCH 057/182] fix type hint of bmm_default --- src/torchjd/autogram/diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 7ceebbffa..0315bc8ef 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -509,7 +509,7 @@ def unique_int(pair: tuple[int, int]) -> int: @implements(aten.bmm.default) -def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: +def bmm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) assert ( mat1.ndim == 3 From 7a84940f60d64c25547c2d810749d57ee3e91e76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 27 Oct 2025 19:43:33 +0100 Subject: [PATCH 058/182] Add mm_default --- src/torchjd/autogram/diagonal_sparse_tensor.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 0315bc8ef..ad16633a5 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -522,3 +522,11 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: # the same physical dimension sizes decompositions. # If not, can reshape to common decomposition? return einsum((mat1, [0, 1, 2]), (mat2, [0, 2, 3]), output=[0, 1, 3]) + + +@implements(aten.mm.default) +def mm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: + assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) + assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] + + return einsum((mat1, [0, 1]), (mat2, [1, 2]), output=[0, 2]) From 30adf305a2480929adca34b9006637ebe4903c34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 27 Oct 2025 19:48:00 +0100 Subject: [PATCH 059/182] Use to_dense when comparing result of einsum in test_einsum --- tests/unit/autogram/test_diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 8573e85af..145ab580c 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -31,7 +31,7 @@ def test_einsum(): res = einsum((a, [0, 1, 2]), (b, [0, 2, 3]), output=[0, 1, 3]) expected = torch.einsum("ijk,ikl->ijl", a.to_dense(), b.to_dense()) - assert_close(res, expected) + assert_close(res.to_dense(), expected) @mark.parametrize( From e6ce46e4d32f7c7c27decd3c2713b00be88bf94a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 27 Oct 2025 20:02:58 +0100 Subject: [PATCH 060/182] Make einsum work with DST only, re-add to_diagonal_sparse_tensor --- .../autogram/diagonal_sparse_tensor.py | 61 +++++++++++-------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index ad16633a5..7c4630d1d 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -151,6 +151,13 @@ def debug_info(self) -> str: return info +def 
to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: + if isinstance(t, DiagonalSparseTensor): + return t + else: + return DiagonalSparseTensor(t, [[i] for i in range(t.ndim)]) + + # pointwise functions applied to one Tensor with `0.0 → 0` _POINTWISE_FUNCTIONS = [ aten.abs.default, @@ -416,7 +423,9 @@ def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSpar return DiagonalSparseTensor(t.contiguous_data, new_v_to_ps) -def einsum(*args: tuple[Tensor, list[int]], output: list[int]) -> DiagonalSparseTensor: +def einsum( + *args: tuple[DiagonalSparseTensor, list[int]], output: list[int] +) -> DiagonalSparseTensor: # TODO: Handle ellipsis # TODO: Should we take only DiagonalSparseTensors and leave the responsability to cast to the # caller? @@ -457,25 +466,22 @@ def group_indices(indices: list[tuple[int, int]]) -> None: tensors = list[Tensor]() indices_to_n_pdims = dict[int, int]() for t, indices in args: - if isinstance(t, DiagonalSparseTensor): - tensors.append(t.contiguous_data) - for ps, index in zip(t.v_to_ps, indices): - if index in indices_to_n_pdims: - assert indices_to_n_pdims[index] == len(ps) - else: - indices_to_n_pdims[index] = len(ps) - p_to_vs = t.p_to_vs() - for indices_ in p_to_vs: - # elements in indices[indices_] map to the same dimension, they should be clustered - # together - group_indices([(indices[i], sub_i) for i, sub_i in indices_]) - # record the physical dimensions, index[v] for v in vs will end-up mapping to the same - # final dimension as they were just clustered, so we can take the first, which exists as - # t is a valid DST. - new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs]) - else: - tensors.append(t) - new_indices_pair.append([(i, 0) for i in indices]) + assert isinstance(t, DiagonalSparseTensor) + tensors.append(t.contiguous_data) + for ps, index in zip(t.v_to_ps, indices): + if index in indices_to_n_pdims: + assert indices_to_n_pdims[index] == len(ps) + else: + indices_to_n_pdims[index] = len(ps) + p_to_vs = t.p_to_vs() + for indices_ in p_to_vs: + # elements in indices[indices_] map to the same dimension, they should be clustered + # together + group_indices([(indices[i], sub_i) for i, sub_i in indices_]) + # record the physical dimensions, index[v] for v in vs will end-up mapping to the same + # final dimension as they were just clustered, so we can take the first, which exists as + # t is a valid DST. + new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs]) current = 0 pair_to_int = dict[tuple[int, int], int]() @@ -518,10 +524,12 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: and mat1.shape[2] == mat2.shape[1] ) - # TODO: Verify that if mat1 and/or mat2 are DiagonalSparseTensors, then their dimension `0` have - # the same physical dimension sizes decompositions. - # If not, can reshape to common decomposition? - return einsum((mat1, [0, 1, 2]), (mat2, [0, 2, 3]), output=[0, 1, 3]) + mat1_ = to_diagonal_sparse_tensor(mat1) + mat2_ = to_diagonal_sparse_tensor(mat2) + + # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes + # decompositions. If not, can reshape to common decomposition? 
+ return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3]) @implements(aten.mm.default) @@ -529,4 +537,7 @@ def mm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] - return einsum((mat1, [0, 1]), (mat2, [1, 2]), output=[0, 2]) + mat1_ = to_diagonal_sparse_tensor(mat1) + mat2_ = to_diagonal_sparse_tensor(mat2) + + return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) From 5253363c1ba567b7108b62e39bb2a8a0d67d34d7 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 28 Oct 2025 00:11:40 +0100 Subject: [PATCH 061/182] Remove list comprehension --- src/torchjd/autogram/diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 7c4630d1d..f01a37d07 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -50,7 +50,7 @@ def __init__(self, data: Tensor, v_to_ps: list[list[int]]): raise ValueError( f"Elements in v_to_ps must map to dimensions in data. Found {v_to_ps}." ) - if len(set().union(*[set(dims) for dims in v_to_ps])) != data.ndim: + if len(set().union(set(dims) for dims in v_to_ps)) != data.ndim: raise ValueError("Every dimension in data must appear at least once in v_to_ps.") self.contiguous_data = data # self.data cannot be used here. From 75d138c024c127041e385520e9a6dc3bd696e483 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 28 Oct 2025 00:48:42 +0100 Subject: [PATCH 062/182] Add sortin function --- src/torchjd/autogram/diagonal_sparse_tensor.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index f01a37d07..13743a958 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -151,6 +151,23 @@ def debug_info(self) -> str: return info +def sort_dst(v_to_ps: list[list[int]]) -> tuple[list[list[int]], list[int]]: + map = dict[int, int]() + curr = 0 + res_v_to_ps = list[list[int]]() + for p_dims in v_to_ps: + new_p_dims = list[int]() + for p_dim in p_dims: + if p_dim not in map: + map[p_dim] = curr + curr += 1 + new_p_dims.append(map[p_dim]) + res_v_to_ps.append(new_p_dims) + + destination = [map[i] for i in range(len(map))] + return res_v_to_ps, destination + + def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: if isinstance(t, DiagonalSparseTensor): return t From f33107b15b8122b7f75c8726a64d27067bdbec13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 01:21:21 +0100 Subject: [PATCH 063/182] Add test_view2 and new parametrization for test_view --- .../autogram/test_diagonal_sparse_tensor.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 145ab580c..9400c3efd 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -126,6 +126,7 @@ def test_unary(func): [[0], [0]], [2, 2, 4], ), # unsquashing into 2 dimensions, need to split physical dimension + ([2, 3, 4], [[0], [0], [1], [2]], [4, 12]), # world boss ], ) def test_view(data_shape: list[int], v_to_ps: list[list[int]], 
target_shape: list[int]): @@ -137,3 +138,49 @@ def test_view(data_shape: list[int], v_to_ps: list[list[int]], target_shape: lis assert isinstance(result, DiagonalSparseTensor) assert torch.all(torch.eq(result.to_dense(), expected)) + + +@mark.parametrize( + ["data_shape", "v_to_ps", "target_shape", "expected_data_shape", "expected_v_to_ps"], + [ + ([2, 3], [[0], [0], [1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # no change of shape + ([2, 3], [[0], [0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # no change of shape + ([2, 3], [[0], [0], [1]], [2, 6], [2, 3], [[0], [0, 1]]), # squashing 2 dimensions + ( + [2, 3], + [[0], [0, 1]], + [2, 2, 3], + [2, 3], + [[0], [0], [1]], + ), # unsquashing into 2 dimensions + ([2, 3], [[0, 0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # unsquashing into 2 dimensions + ([2, 3], [[0], [0], [1]], [12], [2, 3], [[0, 0, 1]]), # squashing 3 dimensions + ([2, 3], [[0, 0, 1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # unsquashing into 3 dimensions + ( + [4], + [[0], [0]], + [2, 2, 4], + [2, 2], + [[0], [1], [0, 1]], + ), # unsquashing into 2 dimensions, need to split physical dimension + ([2, 3, 4], [[0], [0], [1], [2]], [4, 12], [2, 12], [[0, 0], [1]]), # world boss + ([2, 12], [[0, 0], [1]], [2, 2, 3, 4], [2, 3, 4], [[0], [0], [1], [2]]), # world boss + ], +) +def test_view2( + data_shape: list[int], + v_to_ps: list[list[int]], + target_shape: list[int], + expected_data_shape: list[int], + expected_v_to_ps: list[list[int]], +): + a = randn_(tuple(data_shape)) + t = DiagonalSparseTensor(a, v_to_ps) + + result = aten.view.default(t, target_shape) + expected = t.to_dense().reshape(target_shape) + + assert isinstance(result, DiagonalSparseTensor) + assert list(result.contiguous_data.shape) == expected_data_shape + assert result.v_to_ps == expected_v_to_ps + assert torch.all(torch.eq(result.to_dense(), expected)) From 17047f6c5199cb169905c7d4d2e7833b6fb8b36e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 01:22:01 +0100 Subject: [PATCH 064/182] Add test_sort_dst --- .../autogram/test_diagonal_sparse_tensor.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 9400c3efd..62344cea1 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -9,6 +9,7 @@ _POINTWISE_FUNCTIONS, DiagonalSparseTensor, einsum, + sort_dst, ) @@ -184,3 +185,21 @@ def test_view2( assert list(result.contiguous_data.shape) == expected_data_shape assert result.v_to_ps == expected_v_to_ps assert torch.all(torch.eq(result.to_dense(), expected)) + + +@mark.parametrize( + ["v_to_ps", "expected_sorted_v_to_ps", "expected_destination"], + [ + ([[0], [1, 0], [2, 1, 3]], [[0], [1, 0], [2, 1, 3]], [0, 1, 2, 3]), + ([[1, 0], [3, 2, 1]], [[0, 1], [2, 3, 0]], [1, 0, 3, 2]), + ], +) +def test_sort_dst( + v_to_ps: list[list[int]], + expected_sorted_v_to_ps: list[list[int]], + expected_destination: list[int], +): + sorted_v_to_ps, destination = sort_dst(v_to_ps) + + assert sorted_v_to_ps == expected_sorted_v_to_ps + assert destination == expected_destination From 6d1cfa5342ba6cf0ef3ba3517fee7bf899ccecc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 01:40:04 +0100 Subject: [PATCH 065/182] Revert "Remove list comprehension" This reverts commit 5253363c1ba567b7108b62e39bb2a8a0d67d34d7. 
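The first-appearance relabelling exercised by test_sort_dst above can be sketched standalone; the helper name below is illustrative and not part of the library:

def relabel_by_first_appearance(v_to_ps: list[list[int]]) -> tuple[list[list[int]], list[int]]:
    # Relabel physical dims in the order they first appear in the flattened v_to_ps;
    # destination[old_dim] gives the new label assigned to each original physical dim.
    mapping = {}
    relabelled = []
    for ps in v_to_ps:
        new_ps = []
        for p in ps:
            if p not in mapping:
                mapping[p] = len(mapping)
            new_ps.append(mapping[p])
        relabelled.append(new_ps)
    destination = [mapping[p] for p in range(len(mapping))]
    return relabelled, destination

assert relabel_by_first_appearance([[1, 0], [3, 2, 1]]) == ([[0, 1], [2, 3, 0]], [1, 0, 3, 2])
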
--- src/torchjd/autogram/diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 13743a958..773969277 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -50,7 +50,7 @@ def __init__(self, data: Tensor, v_to_ps: list[list[int]]): raise ValueError( f"Elements in v_to_ps must map to dimensions in data. Found {v_to_ps}." ) - if len(set().union(set(dims) for dims in v_to_ps)) != data.ndim: + if len(set().union(*[set(dims) for dims in v_to_ps])) != data.ndim: raise ValueError("Every dimension in data must appear at least once in v_to_ps.") self.contiguous_data = data # self.data cannot be used here. From a5390b7f1a41d8e3604aac2730efc17199169dc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 01:41:28 +0100 Subject: [PATCH 066/182] Rename map to mapping in sort_dst --- src/torchjd/autogram/diagonal_sparse_tensor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 773969277..41e4aa7e7 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -152,19 +152,19 @@ def debug_info(self) -> str: def sort_dst(v_to_ps: list[list[int]]) -> tuple[list[list[int]], list[int]]: - map = dict[int, int]() + mapping = dict[int, int]() curr = 0 res_v_to_ps = list[list[int]]() for p_dims in v_to_ps: new_p_dims = list[int]() for p_dim in p_dims: - if p_dim not in map: - map[p_dim] = curr + if p_dim not in mapping: + mapping[p_dim] = curr curr += 1 - new_p_dims.append(map[p_dim]) + new_p_dims.append(mapping[p_dim]) res_v_to_ps.append(new_p_dims) - destination = [map[i] for i in range(len(map))] + destination = [mapping[i] for i in range(len(mapping))] return res_v_to_ps, destination From 2cc95ea058770c5990b8f77113fd8276f774cb85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 01:53:18 +0100 Subject: [PATCH 067/182] Use empty ps to indicate virtual dim of size one. 
* Adapt a bunch of functions and usages --- src/torchjd/autogram/_engine.py | 5 +- .../autogram/diagonal_sparse_tensor.py | 80 ++++++++++--------- .../autogram/test_diagonal_sparse_tensor.py | 12 +-- 3 files changed, 50 insertions(+), 47 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index cca94becd..cffda6fd5 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -174,8 +174,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - v_to_ps = [[dim] for dim in output_dims * 2] - jac_output = DiagonalSparseTensor(torch.ones_like(output), v_to_ps) + physical_data = torch.ones_like(output).squeeze() + v_to_ps = [[dim] if output.shape[dim] != 1 else [] for dim in output_dims * 2] + jac_output = DiagonalSparseTensor(physical_data, v_to_ps) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 41e4aa7e7..d213f6d3e 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -44,8 +44,20 @@ def __new__(cls, data: Tensor, v_to_ps: list[list[int]]): return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) def __init__(self, data: Tensor, v_to_ps: list[list[int]]): - if not all(len(dims) > 0 for dims in v_to_ps): - raise ValueError(f"All elements of v_to_ps must be non-empty lists. Found {v_to_ps}.") + """ + This constructor is made for specifying data and v_to_ps exactly. It should not modify it. + + For this reason, another constructor will be made to either modify the data / v_to_ps to + simplify the result, or to create a dense tensor directly if it's already dense. It could + also be responsible for sorting the first apparition of each physical dim in the flattened + v_to_ps. + """ + + if any(s == 1 for s in data.shape): + raise ValueError( + "Physical data must not contain any dimension of size 1. Found data.shape=" + f"{data.shape}." + ) if not all(all(0 <= dim < data.ndim for dim in dims) for dims in v_to_ps): raise ValueError( f"Elements in v_to_ps must map to dimensions in data. Found {v_to_ps}." @@ -310,11 +322,10 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor if dim < 0: dim = t.ndim + dim + 1 - new_data = aten.unsqueeze.default(t.contiguous_data, -1) new_v_to_ps = [p for p in t.v_to_ps] # Deepcopy the list to not modify the original v_to_ps - new_v_to_ps.insert(dim, [new_data.ndim - 1]) + new_v_to_ps.insert(dim, []) - return DiagonalSparseTensor(new_data, new_v_to_ps) + return DiagonalSparseTensor(t.contiguous_data, new_v_to_ps) @implements(aten.view.default) @@ -332,15 +343,8 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTen flat_v_to_ps = [dim for dims in t.v_to_ps for dim in dims] p_shape = t.contiguous_data.shape for s in shape: - # Always add the first element of the group, before even entering the while. - # This is because both s and t.contiguous_data.shape[flat_v_to_ps[idx]] could be equal to 1, - # in which case the while will not even be entered but we still want to add the dimension to - # the group. More generally, it's a bit arbitrary in which groups the dimension of length 1 - # are put, but it should rarely be an issue. 
- - group = [flat_v_to_ps[idx]] - current_product = p_shape[flat_v_to_ps[idx]] - idx += 1 + group = [] + current_product = 1 while current_product < s: if idx >= len(flat_v_to_ps): @@ -366,32 +370,34 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseT # note that sizes could also be just an int, or a torch.Size i think assert isinstance(t, DiagonalSparseTensor) assert isinstance(sizes, list) - assert len(sizes) == t.ndim - - new_contiguous_data_shape = [-1] * t.contiguous_data.ndim + assert len(sizes) >= t.ndim - for dim, (original_size, new_size) in enumerate(zip(t.shape, sizes)): - if new_size != original_size: - if original_size != 1: - raise ValueError("Cannot yet expand dim whose size != 1.") + for _ in range(len(sizes) - t.ndim): + t = t.unsqueeze(0) - if len(t.v_to_ps[dim]) != 1: - raise ValueError( - "Cannot yet expand virtual dim corresponding to several physical dims" - ) - - physical_dim = t.v_to_ps[dim][0] + assert len(sizes) == t.ndim - # Verify that we don't have two virtual dims expanding the same physical dim differently - previous_value = new_contiguous_data_shape[physical_dim] - assert previous_value == -1 or previous_value == new_size + new_contiguous_data = t.contiguous_data + new_v_to_ps = t.v_to_ps + n_added_physical_dims = 0 + for dim, (ps, orig_size, new_size) in enumerate(zip(t.v_to_ps, t.shape, sizes, strict=True)): + if len(ps) > 0 and orig_size != new_size and new_size != -1: + raise ValueError( + f"Cannot expand dim {dim} of size != 1. Found size {orig_size} and target size " + f"{new_size}." + ) - new_contiguous_data_shape[physical_dim] = new_size + if len(ps) == 0 and new_size != 1 and new_size != -1: + # Add a dimension of size new_size at the end of the physical tensor. + new_physical_shape = list(new_contiguous_data.shape) + [new_size] + new_contiguous_data = new_contiguous_data.unsqueeze(-1).expand(new_physical_shape) + new_v_to_ps[dim] = [t.contiguous_data.ndim + n_added_physical_dims] + n_added_physical_dims += 1 - new_contiguous_data = aten.expand.default(t.contiguous_data, new_contiguous_data_shape) + new_v_to_ps, destination = sort_dst(new_v_to_ps) + new_contiguous_data = new_contiguous_data.movedim(list(range(len(destination))), destination) - # Not sure if it's safe to just provide v_to_ps as-is. I think we're supposed to copy it. 
- return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) + return DiagonalSparseTensor(new_contiguous_data, new_v_to_ps) @implements(aten.div.Scalar) @@ -421,9 +427,9 @@ def slice_Tensor( @implements(aten.mul.Tensor) -def mul_Tensor(t1: Tensor, t2: DiagonalSparseTensor) -> DiagonalSparseTensor: - # Element-wise multiplication where t1 is dense and t2 is DST - assert isinstance(t2, DiagonalSparseTensor) +def mul_Tensor(t1: Tensor, t2: Tensor) -> DiagonalSparseTensor: + # Element-wise multiplication + assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) new_contiguous_data = aten.mul.Tensor(t1, t2.contiguous_data) return DiagonalSparseTensor(new_contiguous_data, t2.v_to_ps) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 62344cea1..569f4957b 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -39,13 +39,9 @@ def test_einsum(): "shape", [ [], - [1], - [3], - [1, 1], - [1, 4], - [3, 1], - [3, 4], - [1, 2, 3], + [2], + [2, 3], + [2, 3, 4], ], ) def test_diagonal_sparse_tensor_scalar(shape: list[int]): @@ -55,7 +51,7 @@ def test_diagonal_sparse_tensor_scalar(shape: list[int]): assert_close(a, b.to_dense()) -@mark.parametrize("dim", [1, 2, 3, 4, 5, 10]) +@mark.parametrize("dim", [2, 3, 4, 5, 10]) def test_diag_equivalence(dim: int): a = randn_([dim]) b = DiagonalSparseTensor(a, [[0], [0]]) From ef5a2b10cce3f939a80224853de8df65067ef8f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 01:55:38 +0100 Subject: [PATCH 068/182] Remove outdated todo in einsum --- src/torchjd/autogram/diagonal_sparse_tensor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index d213f6d3e..1bc9413bb 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -450,9 +450,6 @@ def einsum( *args: tuple[DiagonalSparseTensor, list[int]], output: list[int] ) -> DiagonalSparseTensor: # TODO: Handle ellipsis - # TODO: Should we take only DiagonalSparseTensors and leave the responsability to cast to the - # caller? - # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. 
# For this reason, an index is decomposed into sub-indices that are then independently From ea7d4dd07f1891227840032398847c028feda1a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 02:13:26 +0100 Subject: [PATCH 069/182] Improve implementation of sort_dst * Make low-level implementation called first_sort that works on list[int] directly instead of list[list[int]] * Rename sort_dst to first_sort_v_to_ps * Make first_sort_v_to_ps use first_sort, tree_flatten and tree_unflatten * Change test to test the lower-level function, and add parametrizations --- .../autogram/diagonal_sparse_tensor.py | 43 +++++++++++++------ .../autogram/test_diagonal_sparse_tensor.py | 21 +++++---- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py index 1bc9413bb..03d4f3021 100644 --- a/src/torchjd/autogram/diagonal_sparse_tensor.py +++ b/src/torchjd/autogram/diagonal_sparse_tensor.py @@ -5,7 +5,7 @@ import torch from torch import Tensor from torch.ops import aten # type: ignore -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten _HANDLED_FUNCTIONS = dict() from functools import wraps @@ -163,21 +163,36 @@ def debug_info(self) -> str: return info -def sort_dst(v_to_ps: list[list[int]]) -> tuple[list[list[int]], list[int]]: +def first_sort(input: list[int]) -> tuple[list[int], list[int]]: + """ + Sorts a list of ints so that the first element to appear for the first time is 0, the second is + 1, etc. Elements may appear anywhere after their first appearance. Returns the sorted list and + list corresponding to the destination of each original int. destination[i] = j means that + all elements of value i in input are mapping to j in sorted list. 
+ + Examples: + [1, 0, 3, 2] => [0, 1, 2, 3], [1, 0, 3, 2] + [0, 2, 0, 1] => [0, 1, 0, 2], [0, 2, 1] + [1, 0, 0, 1] => [0, 1, 1, 0], [1, 0] + """ + mapping = dict[int, int]() curr = 0 - res_v_to_ps = list[list[int]]() - for p_dims in v_to_ps: - new_p_dims = list[int]() - for p_dim in p_dims: - if p_dim not in mapping: - mapping[p_dim] = curr - curr += 1 - new_p_dims.append(mapping[p_dim]) - res_v_to_ps.append(new_p_dims) - + output = [] + for v in input: + if v not in mapping: + mapping[v] = curr + curr += 1 + output.append(mapping[v]) destination = [mapping[i] for i in range(len(mapping))] - return res_v_to_ps, destination + + return output, destination + + +def first_sort_v_to_ps(v_to_ps: list[list[int]]) -> tuple[list[list[int]], list[int]]: + flat_v_to_ps, spec = tree_flatten(v_to_ps) + sorted_flat_v_to_ps, destination = first_sort(flat_v_to_ps) + return tree_unflatten(sorted_flat_v_to_ps, spec), destination def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: @@ -394,7 +409,7 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseT new_v_to_ps[dim] = [t.contiguous_data.ndim + n_added_physical_dims] n_added_physical_dims += 1 - new_v_to_ps, destination = sort_dst(new_v_to_ps) + new_v_to_ps, destination = first_sort_v_to_ps(new_v_to_ps) new_contiguous_data = new_contiguous_data.movedim(list(range(len(destination))), destination) return DiagonalSparseTensor(new_contiguous_data, new_v_to_ps) diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/autogram/test_diagonal_sparse_tensor.py index 569f4957b..f03483102 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/autogram/test_diagonal_sparse_tensor.py @@ -9,7 +9,7 @@ _POINTWISE_FUNCTIONS, DiagonalSparseTensor, einsum, - sort_dst, + first_sort, ) @@ -184,18 +184,21 @@ def test_view2( @mark.parametrize( - ["v_to_ps", "expected_sorted_v_to_ps", "expected_destination"], + ["input", "expected_output", "expected_destination"], [ - ([[0], [1, 0], [2, 1, 3]], [[0], [1, 0], [2, 1, 3]], [0, 1, 2, 3]), - ([[1, 0], [3, 2, 1]], [[0, 1], [2, 3, 0]], [1, 0, 3, 2]), + ([0, 1, 0, 2, 1, 3], [0, 1, 0, 2, 1, 3], [0, 1, 2, 3]), # trivial + ([1, 0, 3, 2, 1], [0, 1, 2, 3, 0], [1, 0, 3, 2]), + ([1, 0, 3, 2], [0, 1, 2, 3], [1, 0, 3, 2]), + ([0, 2, 0, 1], [0, 1, 0, 2], [0, 2, 1]), + ([1, 0, 0, 1], [0, 1, 1, 0], [1, 0]), ], ) -def test_sort_dst( - v_to_ps: list[list[int]], - expected_sorted_v_to_ps: list[list[int]], +def test_first_sort( + input: list[int], + expected_output: list[int], expected_destination: list[int], ): - sorted_v_to_ps, destination = sort_dst(v_to_ps) + output, destination = first_sort(input) - assert sorted_v_to_ps == expected_sorted_v_to_ps + assert output == expected_output assert destination == expected_destination From 8403b1138a7587150a70f7089451dadde7edfd06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 02:19:53 +0100 Subject: [PATCH 070/182] Restructure packages * Add sparse package in torchjd * Move diagonal_sparse_tensor to it * Make diagonal_sparse_tensor protected * For now sparse is public. We may want to make it protected. 
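With the move in place, a minimal usage sketch (import path as introduced by this patch, reusing the diag equivalence checked in the unit tests):

import torch

from torchjd.sparse import DiagonalSparseTensor

a = torch.randn(3)
t = DiagonalSparseTensor(a, [[0], [0]])      # both virtual dims map to the single physical dim
assert torch.allclose(t.to_dense(), torch.diag(a))
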
--- src/torchjd/autogram/_engine.py | 3 ++- src/torchjd/sparse/__init__.py | 1 + .../_diagonal_sparse_tensor.py} | 0 tests/unit/sparse/__init__.py | 0 tests/unit/{autogram => sparse}/test_diagonal_sparse_tensor.py | 2 +- 5 files changed, 4 insertions(+), 2 deletions(-) create mode 100644 src/torchjd/sparse/__init__.py rename src/torchjd/{autogram/diagonal_sparse_tensor.py => sparse/_diagonal_sparse_tensor.py} (100%) create mode 100644 tests/unit/sparse/__init__.py rename tests/unit/{autogram => sparse}/test_diagonal_sparse_tensor.py (99%) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index cffda6fd5..0a4be030e 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,12 +4,13 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge +from torchjd.sparse import DiagonalSparseTensor + from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms from ._jacobian_computer import AutogradJacobianComputer from ._module_hook_manager import ModuleHookManager -from .diagonal_sparse_tensor import DiagonalSparseTensor class Engine: diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py new file mode 100644 index 000000000..b54263e55 --- /dev/null +++ b/src/torchjd/sparse/__init__.py @@ -0,0 +1 @@ +from ._diagonal_sparse_tensor import DiagonalSparseTensor diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py similarity index 100% rename from src/torchjd/autogram/diagonal_sparse_tensor.py rename to src/torchjd/sparse/_diagonal_sparse_tensor.py diff --git a/tests/unit/sparse/__init__.py b/tests/unit/sparse/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/autogram/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py similarity index 99% rename from tests/unit/autogram/test_diagonal_sparse_tensor.py rename to tests/unit/sparse/test_diagonal_sparse_tensor.py index f03483102..2b1fc3a51 100644 --- a/tests/unit/autogram/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -4,7 +4,7 @@ from torch.testing import assert_close from utils.tensors import randn_, zeros_ -from torchjd.autogram.diagonal_sparse_tensor import ( +from torchjd.sparse._diagonal_sparse_tensor import ( _IN_PLACE_POINTWISE_FUNCTIONS, _POINTWISE_FUNCTIONS, DiagonalSparseTensor, From f2b2ef6d912925e5a83b332d869d025be6d2de41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 02:53:10 +0100 Subject: [PATCH 071/182] Make _HANDLED_FUNCTIONS and implements class attributes of DST --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 61 +++++++++---------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 03d4f3021..990879a23 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -1,4 +1,5 @@ import operator +from functools import wraps from itertools import accumulate from math import prod @@ -7,22 +8,9 @@ from torch.ops import aten # type: ignore from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten -_HANDLED_FUNCTIONS = dict() -from functools import wraps - - -def implements(torch_function): - """Register a torch function override for 
ScalarTensor""" - - @wraps(func) - def decorator(func): - _HANDLED_FUNCTIONS[torch_function] = func - return func - - return decorator - class DiagonalSparseTensor(torch.Tensor): + _HANDLED_FUNCTIONS = dict() @staticmethod def __new__(cls, data: Tensor, v_to_ps: list[list[int]]): @@ -128,8 +116,8 @@ def p_to_vs(self) -> list[list[tuple[int, int]]]: def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs - if func in _HANDLED_FUNCTIONS: - return _HANDLED_FUNCTIONS[func](*args, **kwargs) + if func in cls._HANDLED_FUNCTIONS: + return cls._HANDLED_FUNCTIONS[func](*args, **kwargs) # --- Fallback: Fold to Dense Tensor --- def unwrap_to_dense(t: Tensor): @@ -162,6 +150,17 @@ def debug_info(self) -> str: ) return info + @classmethod + def implements(cls, torch_function): + """Register a torch function override for ScalarTensor""" + + @wraps(torch_function) + def decorator(func): + cls._HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + def first_sort(input: list[int]) -> tuple[list[int], list[int]]: """ @@ -268,7 +267,7 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: def _override_pointwise(op): - @implements(op) + @DiagonalSparseTensor.implements(op) def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) return DiagonalSparseTensor(op(t.contiguous_data), t.v_to_ps) @@ -277,7 +276,7 @@ def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: def _override_inplace_pointwise(op): - @implements(op) + @DiagonalSparseTensor.implements(op) def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) op(t.contiguous_data) @@ -291,19 +290,19 @@ def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: _override_inplace_pointwise(func) -@implements(aten.mean.default) +@DiagonalSparseTensor.implements(aten.mean.default) def mean_default(t: DiagonalSparseTensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) return aten.sum.default(t.contiguous_data) / t.numel() -@implements(aten.sum.default) +@DiagonalSparseTensor.implements(aten.sum.default) def sum_default(t: DiagonalSparseTensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) return aten.sum.default(t.contiguous_data) -@implements(aten.pow.Tensor_Scalar) +@DiagonalSparseTensor.implements(aten.pow.Tensor_Scalar) def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) @@ -316,7 +315,7 @@ def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSpars # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. 
-@implements(aten.pow_.Scalar) +@DiagonalSparseTensor.implements(aten.pow_.Scalar) def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) @@ -329,7 +328,7 @@ def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTenso return t -@implements(aten.unsqueeze.default) +@DiagonalSparseTensor.implements(aten.unsqueeze.default) def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) assert -t.ndim - 1 <= dim < t.ndim + 1 @@ -343,7 +342,7 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor return DiagonalSparseTensor(t.contiguous_data, new_v_to_ps) -@implements(aten.view.default) +@DiagonalSparseTensor.implements(aten.view.default) def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTensor: # TODO: add error message when error is raised # TODO: handle case where the contiguous_data has to be reshaped @@ -380,7 +379,7 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTen return DiagonalSparseTensor(t.contiguous_data, new_v_to_ps) -@implements(aten.expand.default) +@DiagonalSparseTensor.implements(aten.expand.default) def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseTensor: # note that sizes could also be just an int, or a torch.Size i think assert isinstance(t, DiagonalSparseTensor) @@ -415,7 +414,7 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseT return DiagonalSparseTensor(new_contiguous_data, new_v_to_ps) -@implements(aten.div.Scalar) +@DiagonalSparseTensor.implements(aten.div.Scalar) def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) @@ -423,7 +422,7 @@ def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor: return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) -@implements(aten.slice.Tensor) +@DiagonalSparseTensor.implements(aten.slice.Tensor) def slice_Tensor( t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 ) -> DiagonalSparseTensor: @@ -441,7 +440,7 @@ def slice_Tensor( return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) -@implements(aten.mul.Tensor) +@DiagonalSparseTensor.implements(aten.mul.Tensor) def mul_Tensor(t1: Tensor, t2: Tensor) -> DiagonalSparseTensor: # Element-wise multiplication assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) @@ -450,7 +449,7 @@ def mul_Tensor(t1: Tensor, t2: Tensor) -> DiagonalSparseTensor: return DiagonalSparseTensor(new_contiguous_data, t2.v_to_ps) -@implements(aten.transpose.int) +@DiagonalSparseTensor.implements(aten.transpose.int) def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) @@ -549,7 +548,7 @@ def unique_int(pair: tuple[int, int]) -> int: return DiagonalSparseTensor(data, v_to_ps) -@implements(aten.bmm.default) +@DiagonalSparseTensor.implements(aten.bmm.default) def bmm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) assert ( @@ -567,7 +566,7 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3]) -@implements(aten.mm.default) +@DiagonalSparseTensor.implements(aten.mm.default) def 
mm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] From 97213eb5b33fd146a88a178cdfb2c60ea7f6aac0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 02:57:33 +0100 Subject: [PATCH 072/182] Move pointwise function implementations to the end of the file and avoid name shadowing --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 178 +++++++++--------- 1 file changed, 89 insertions(+), 89 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 990879a23..6676e141f 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -201,95 +201,6 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: return DiagonalSparseTensor(t, [[i] for i in range(t.ndim)]) -# pointwise functions applied to one Tensor with `0.0 → 0` -_POINTWISE_FUNCTIONS = [ - aten.abs.default, - aten.absolute.default, - aten.asin.default, - aten.asinh.default, - aten.atan.default, - aten.atanh.default, - aten.ceil.default, - aten.erf.default, - aten.erfinv.default, - aten.expm1.default, - aten.fix.default, - aten.floor.default, - aten.hardtanh.default, - aten.leaky_relu.default, - aten.log1p.default, - aten.neg.default, - aten.negative.default, - aten.positive.default, - aten.relu.default, - aten.round.default, - aten.sgn.default, - aten.sign.default, - aten.sin.default, - aten.sinh.default, - aten.sqrt.default, - aten.square.default, - aten.tan.default, - aten.tanh.default, - aten.trunc.default, -] - -_IN_PLACE_POINTWISE_FUNCTIONS = [ - aten.abs_.default, - aten.absolute_.default, - aten.asin_.default, - aten.asinh_.default, - aten.atan_.default, - aten.atanh_.default, - aten.ceil_.default, - aten.erf_.default, - aten.erfinv_.default, - aten.expm1_.default, - aten.fix_.default, - aten.floor_.default, - aten.hardtanh_.default, - aten.leaky_relu_.default, - aten.log1p_.default, - aten.neg_.default, - aten.negative_.default, - aten.relu_.default, - aten.round_.default, - aten.sgn_.default, - aten.sign_.default, - aten.sin_.default, - aten.sinh_.default, - aten.sqrt_.default, - aten.square_.default, - aten.tan_.default, - aten.tanh_.default, - aten.trunc_.default, -] - - -def _override_pointwise(op): - @DiagonalSparseTensor.implements(op) - def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - return DiagonalSparseTensor(op(t.contiguous_data), t.v_to_ps) - - return func_ - - -def _override_inplace_pointwise(op): - @DiagonalSparseTensor.implements(op) - def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - op(t.contiguous_data) - return t - - -for func in _POINTWISE_FUNCTIONS: - _override_pointwise(func) - -for func in _IN_PLACE_POINTWISE_FUNCTIONS: - _override_inplace_pointwise(func) - - @DiagonalSparseTensor.implements(aten.mean.default) def mean_default(t: DiagonalSparseTensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) @@ -575,3 +486,92 @@ def mm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: mat2_ = to_diagonal_sparse_tensor(mat2) return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) + + +# pointwise functions applied to one Tensor with `0.0 → 0` +_POINTWISE_FUNCTIONS = [ + aten.abs.default, + aten.absolute.default, + aten.asin.default, + 
aten.asinh.default, + aten.atan.default, + aten.atanh.default, + aten.ceil.default, + aten.erf.default, + aten.erfinv.default, + aten.expm1.default, + aten.fix.default, + aten.floor.default, + aten.hardtanh.default, + aten.leaky_relu.default, + aten.log1p.default, + aten.neg.default, + aten.negative.default, + aten.positive.default, + aten.relu.default, + aten.round.default, + aten.sgn.default, + aten.sign.default, + aten.sin.default, + aten.sinh.default, + aten.sqrt.default, + aten.square.default, + aten.tan.default, + aten.tanh.default, + aten.trunc.default, +] + +_IN_PLACE_POINTWISE_FUNCTIONS = [ + aten.abs_.default, + aten.absolute_.default, + aten.asin_.default, + aten.asinh_.default, + aten.atan_.default, + aten.atanh_.default, + aten.ceil_.default, + aten.erf_.default, + aten.erfinv_.default, + aten.expm1_.default, + aten.fix_.default, + aten.floor_.default, + aten.hardtanh_.default, + aten.leaky_relu_.default, + aten.log1p_.default, + aten.neg_.default, + aten.negative_.default, + aten.relu_.default, + aten.round_.default, + aten.sgn_.default, + aten.sign_.default, + aten.sin_.default, + aten.sinh_.default, + aten.sqrt_.default, + aten.square_.default, + aten.tan_.default, + aten.tanh_.default, + aten.trunc_.default, +] + + +def _override_pointwise(op): + @DiagonalSparseTensor.implements(op) + def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + return DiagonalSparseTensor(op(t.contiguous_data), t.v_to_ps) + + return func_ + + +def _override_inplace_pointwise(op): + @DiagonalSparseTensor.implements(op) + def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + op(t.contiguous_data) + return t + + +for pointwise_func in _POINTWISE_FUNCTIONS: + _override_pointwise(pointwise_func) + +for pointwise_func in _IN_PLACE_POINTWISE_FUNCTIONS: + _override_inplace_pointwise(pointwise_func) From 866b4b180a45272dfd78e8a4bca955f6b0762a80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 03:08:13 +0100 Subject: [PATCH 073/182] Uniformize name of physical * {data, contiguous_data, physical_data} => physical --- src/torchjd/autogram/_engine.py | 4 +- src/torchjd/sparse/_diagonal_sparse_tensor.py | 120 +++++++++--------- .../sparse/test_diagonal_sparse_tensor.py | 16 +-- 3 files changed, 69 insertions(+), 71 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 0a4be030e..98156f354 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -175,9 +175,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - physical_data = torch.ones_like(output).squeeze() + physical = torch.ones_like(output).squeeze() v_to_ps = [[dim] if output.shape[dim] != 1 else [] for dim in output_dims * 2] - jac_output = DiagonalSparseTensor(physical_data, v_to_ps) + jac_output = DiagonalSparseTensor(physical, v_to_ps) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 6676e141f..26f36be45 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -13,10 +13,10 @@ class DiagonalSparseTensor(torch.Tensor): _HANDLED_FUNCTIONS = dict() @staticmethod - def __new__(cls, data: Tensor, v_to_ps: list[list[int]]): + def __new__(cls, physical: Tensor, v_to_ps: list[list[int]]): # At the moment, this class is 
not compositional, so we assert # that the tensor we're wrapping is exactly a Tensor - assert type(data) is Tensor + assert type(physical) is Tensor # Note [Passing requires_grad=true tensors to subclasses] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -26,45 +26,47 @@ def __new__(cls, data: Tensor, v_to_ps: list[list[int]]): # representing the "constructor" (NegativeView, in this case) # and call that instead. This assert helps prevent direct usage # (which is bad!) - assert not data.requires_grad or not torch.is_grad_enabled() + assert not physical.requires_grad or not torch.is_grad_enabled() - shape = [prod(data.shape[i] for i in dims) for dims in v_to_ps] - return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device) + shape = [prod(physical.shape[i] for i in dims) for dims in v_to_ps] + return Tensor._make_wrapper_subclass( + cls, shape, dtype=physical.dtype, device=physical.device + ) - def __init__(self, data: Tensor, v_to_ps: list[list[int]]): + def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): """ - This constructor is made for specifying data and v_to_ps exactly. It should not modify it. + This constructor is made for specifying physical and v_to_ps exactly. It should not modify it. - For this reason, another constructor will be made to either modify the data / v_to_ps to + For this reason, another constructor will be made to either modify the physical / v_to_ps to simplify the result, or to create a dense tensor directly if it's already dense. It could also be responsible for sorting the first apparition of each physical dim in the flattened v_to_ps. """ - if any(s == 1 for s in data.shape): + if any(s == 1 for s in physical.shape): raise ValueError( - "Physical data must not contain any dimension of size 1. Found data.shape=" - f"{data.shape}." + "physical must not contain any dimension of size 1. Found physical.shape=" + f"{physical.shape}." ) - if not all(all(0 <= dim < data.ndim for dim in dims) for dims in v_to_ps): + if not all(all(0 <= dim < physical.ndim for dim in dims) for dims in v_to_ps): raise ValueError( - f"Elements in v_to_ps must map to dimensions in data. Found {v_to_ps}." + f"Elements in v_to_ps must map to dimensions in physical. Found {v_to_ps}." ) - if len(set().union(*[set(dims) for dims in v_to_ps])) != data.ndim: - raise ValueError("Every dimension in data must appear at least once in v_to_ps.") + if len(set().union(*[set(dims) for dims in v_to_ps])) != physical.ndim: + raise ValueError("Every dimension in physical must appear at least once in v_to_ps.") - self.contiguous_data = data # self.data cannot be used here. + self.physical = physical self.v_to_ps = v_to_ps # This is a list of strides whose shape matches that of v_to_ps except that each element # is the stride factor of the index to get the right element for the corresponding virtual # dimension. Stride is the jump necessary to go from one element to the next one in the # specified dimension. 
For instance if the i'th element of v_to_ps is [0, 1, 2], then the - # i'th element of _strides is [data.shape[1] * data.shape[2], data.shape[2], 1] and so, if - # we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2], which is - # a unique decomposition, then this corresponds to indexing dimensions v_to_ps[i] at indices - # [j_0, j_1, j_2] - s = data.shape + # i'th element of _strides is [physical.shape[1] * physical.shape[2], physical.shape[2], 1] + # and so, if we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2] + # which isa unique decomposition, then this corresponds to indexing dimensions v_to_ps[i] at + # indices [j_0, j_1, j_2] + s = physical.shape self._strides = [ list(accumulate([1] + [s[dim] for dim in dims[:0:-1]], operator.mul))[::-1] for dims in v_to_ps @@ -76,15 +78,13 @@ def to_dense( assert dtype is None # We may add support for this later assert masked_grad is None # We may add support for this later - if self.contiguous_data.ndim == 0: - return self.contiguous_data - p_index_ranges = [ - torch.arange(s, device=self.contiguous_data.device) for s in self.contiguous_data.shape - ] + if self.physical.ndim == 0: + return self.physical + p_index_ranges = [torch.arange(s, device=self.physical.device) for s in self.physical.shape] p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") v_indices_grid = list[Tensor]() for stride, dims in zip(self._strides, self.v_to_ps): - stride_ = torch.tensor(stride, device=self.contiguous_data.device, dtype=torch.int) + stride_ = torch.tensor(stride, device=self.physical.device, dtype=torch.int) v_indices_grid.append( torch.sum(torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1) ) @@ -92,10 +92,8 @@ def to_dense( # whose elements are the coordinates 1 in p_indices_grad[d_1] times stride 1 # plus coordinates 2 in p_indices_grad[d_2] times stride 2, etc... 
- res = torch.zeros( - self.shape, device=self.contiguous_data.device, dtype=self.contiguous_data.dtype - ) - res[tuple(v_indices_grid)] = self.contiguous_data + res = torch.zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) + res[tuple(v_indices_grid)] = self.physical return res def p_to_vs(self) -> list[list[tuple[int, int]]]: @@ -136,7 +134,7 @@ def unwrap_to_dense(t: Tensor): def __repr__(self): return ( - f"DiagonalSparseTensor(data={self.contiguous_data}, v_to_ps_map={self.v_to_ps}, shape=" + f"DiagonalSparseTensor(physical={self.physical}, v_to_ps_map={self.v_to_ps}, shape=" f"{self.shape})" ) @@ -145,8 +143,8 @@ def debug_info(self) -> str: f"shape: {self.shape}\n" f"stride(): {self.stride()}\n" f"v_to_ps: {self.v_to_ps}\n" - f"contiguous_data.shape: {self.contiguous_data.shape}\n" - f"contiguous_data.stride(): {self.contiguous_data.stride()}" + f"physical.shape: {self.physical.shape}\n" + f"physical.stride(): {self.physical.stride()}" ) return info @@ -204,13 +202,13 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: @DiagonalSparseTensor.implements(aten.mean.default) def mean_default(t: DiagonalSparseTensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t.contiguous_data) / t.numel() + return aten.sum.default(t.physical) / t.numel() @DiagonalSparseTensor.implements(aten.sum.default) def sum_default(t: DiagonalSparseTensor) -> Tensor: assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t.contiguous_data) + return aten.sum.default(t.physical) @DiagonalSparseTensor.implements(aten.pow.Tensor_Scalar) @@ -221,8 +219,8 @@ def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSpars # Need to densify because we don't have pow(0.0, exponent) = 0.0 return aten.pow.Tensor_Scalar(t.to_dense(), exponent) - new_contiguous_data = aten.pow.Tensor_Scalar(t.contiguous_data, exponent) - return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) + new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) + return DiagonalSparseTensor(new_physical, t.v_to_ps) # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. @@ -235,7 +233,7 @@ def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTenso # Note sure if it's even possible to densify in-place, so let's just raise an error. 
raise ValueError(f"in-place pow with an exponent of {exponent} (<= 0) is not supported.") - aten.pow_.Scalar(t.contiguous_data, exponent) + aten.pow_.Scalar(t.physical, exponent) return t @@ -250,13 +248,13 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor new_v_to_ps = [p for p in t.v_to_ps] # Deepcopy the list to not modify the original v_to_ps new_v_to_ps.insert(dim, []) - return DiagonalSparseTensor(t.contiguous_data, new_v_to_ps) + return DiagonalSparseTensor(t.physical, new_v_to_ps) @DiagonalSparseTensor.implements(aten.view.default) def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTensor: # TODO: add error message when error is raised - # TODO: handle case where the contiguous_data has to be reshaped + # TODO: handle case where the physical has to be reshaped assert isinstance(t, DiagonalSparseTensor) @@ -266,7 +264,7 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTen new_v_to_ps = [] idx = 0 flat_v_to_ps = [dim for dims in t.v_to_ps for dim in dims] - p_shape = t.contiguous_data.shape + p_shape = t.physical.shape for s in shape: group = [] current_product = 1 @@ -287,7 +285,7 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTen if idx != len(flat_v_to_ps): raise ValueError(f"idx != len(flat_v_to_ps). {idx}; {flat_v_to_ps}; {shape}; {t.v_to_ps}") - return DiagonalSparseTensor(t.contiguous_data, new_v_to_ps) + return DiagonalSparseTensor(t.physical, new_v_to_ps) @DiagonalSparseTensor.implements(aten.expand.default) @@ -302,7 +300,7 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseT assert len(sizes) == t.ndim - new_contiguous_data = t.contiguous_data + new_physical = t.physical new_v_to_ps = t.v_to_ps n_added_physical_dims = 0 for dim, (ps, orig_size, new_size) in enumerate(zip(t.v_to_ps, t.shape, sizes, strict=True)): @@ -314,23 +312,23 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseT if len(ps) == 0 and new_size != 1 and new_size != -1: # Add a dimension of size new_size at the end of the physical tensor. 
- new_physical_shape = list(new_contiguous_data.shape) + [new_size] - new_contiguous_data = new_contiguous_data.unsqueeze(-1).expand(new_physical_shape) - new_v_to_ps[dim] = [t.contiguous_data.ndim + n_added_physical_dims] + new_physical_shape = list(new_physical.shape) + [new_size] + new_physical = new_physical.unsqueeze(-1).expand(new_physical_shape) + new_v_to_ps[dim] = [t.physical.ndim + n_added_physical_dims] n_added_physical_dims += 1 new_v_to_ps, destination = first_sort_v_to_ps(new_v_to_ps) - new_contiguous_data = new_contiguous_data.movedim(list(range(len(destination))), destination) + new_physical = new_physical.movedim(list(range(len(destination))), destination) - return DiagonalSparseTensor(new_contiguous_data, new_v_to_ps) + return DiagonalSparseTensor(new_physical, new_v_to_ps) @DiagonalSparseTensor.implements(aten.div.Scalar) def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) - new_contiguous_data = aten.div.Scalar(t.contiguous_data, divisor) - return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) + new_physical = aten.div.Scalar(t.physical, divisor) + return DiagonalSparseTensor(new_physical, t.v_to_ps) @DiagonalSparseTensor.implements(aten.slice.Tensor) @@ -346,9 +344,9 @@ def slice_Tensor( physical_dim = physical_dims[0] - new_contiguous_data = aten.slice.Tensor(t.contiguous_data, physical_dim, start, end, step) + new_physical = aten.slice.Tensor(t.physical, physical_dim, start, end, step) - return DiagonalSparseTensor(new_contiguous_data, t.v_to_ps) + return DiagonalSparseTensor(new_physical, t.v_to_ps) @DiagonalSparseTensor.implements(aten.mul.Tensor) @@ -356,8 +354,8 @@ def mul_Tensor(t1: Tensor, t2: Tensor) -> DiagonalSparseTensor: # Element-wise multiplication assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) - new_contiguous_data = aten.mul.Tensor(t1, t2.contiguous_data) - return DiagonalSparseTensor(new_contiguous_data, t2.v_to_ps) + new_physical = aten.mul.Tensor(t1, t2.physical) + return DiagonalSparseTensor(new_physical, t2.v_to_ps) @DiagonalSparseTensor.implements(aten.transpose.int) @@ -368,7 +366,7 @@ def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSpar new_v_to_ps[dim0] = t.v_to_ps[dim1] new_v_to_ps[dim1] = t.v_to_ps[dim0] - return DiagonalSparseTensor(t.contiguous_data, new_v_to_ps) + return DiagonalSparseTensor(t.physical, new_v_to_ps) def einsum( @@ -412,7 +410,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None: indices_to_n_pdims = dict[int, int]() for t, indices in args: assert isinstance(t, DiagonalSparseTensor) - tensors.append(t.contiguous_data) + tensors.append(t.physical) for ps, index in zip(t.v_to_ps, indices): if index in indices_to_n_pdims: assert indices_to_n_pdims[index] == len(ps) @@ -455,8 +453,8 @@ def unique_int(pair: tuple[int, int]) -> int: new_output.append(k) v_to_ps.append(current_v_to_ps) - data = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output) - return DiagonalSparseTensor(data, v_to_ps) + physical = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output) + return DiagonalSparseTensor(physical, v_to_ps) @DiagonalSparseTensor.implements(aten.bmm.default) @@ -557,7 +555,7 @@ def _override_pointwise(op): @DiagonalSparseTensor.implements(op) def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) - return DiagonalSparseTensor(op(t.contiguous_data), t.v_to_ps) + return 
DiagonalSparseTensor(op(t.physical), t.v_to_ps) return func_ @@ -566,7 +564,7 @@ def _override_inplace_pointwise(op): @DiagonalSparseTensor.implements(op) def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) - op(t.contiguous_data) + op(t.physical) return t diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 2b1fc3a51..c61928fd9 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -109,7 +109,7 @@ def test_unary(func): @mark.parametrize( - ["data_shape", "v_to_ps", "target_shape"], + ["physical_shape", "v_to_ps", "target_shape"], [ ([2, 3], [[0], [0], [1]], [2, 2, 3]), # no change of shape ([2, 3], [[0], [0, 1]], [2, 6]), # no change of shape @@ -126,8 +126,8 @@ def test_unary(func): ([2, 3, 4], [[0], [0], [1], [2]], [4, 12]), # world boss ], ) -def test_view(data_shape: list[int], v_to_ps: list[list[int]], target_shape: list[int]): - a = randn_(tuple(data_shape)) +def test_view(physical_shape: list[int], v_to_ps: list[list[int]], target_shape: list[int]): + a = randn_(tuple(physical_shape)) t = DiagonalSparseTensor(a, v_to_ps) result = aten.view.default(t, target_shape) @@ -138,7 +138,7 @@ def test_view(data_shape: list[int], v_to_ps: list[list[int]], target_shape: lis @mark.parametrize( - ["data_shape", "v_to_ps", "target_shape", "expected_data_shape", "expected_v_to_ps"], + ["physical_shape", "v_to_ps", "target_shape", "expected_physical_shape", "expected_v_to_ps"], [ ([2, 3], [[0], [0], [1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # no change of shape ([2, 3], [[0], [0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # no change of shape @@ -165,20 +165,20 @@ def test_view(data_shape: list[int], v_to_ps: list[list[int]], target_shape: lis ], ) def test_view2( - data_shape: list[int], + physical_shape: list[int], v_to_ps: list[list[int]], target_shape: list[int], - expected_data_shape: list[int], + expected_physical_shape: list[int], expected_v_to_ps: list[list[int]], ): - a = randn_(tuple(data_shape)) + a = randn_(tuple(physical_shape)) t = DiagonalSparseTensor(a, v_to_ps) result = aten.view.default(t, target_shape) expected = t.to_dense().reshape(target_shape) assert isinstance(result, DiagonalSparseTensor) - assert list(result.contiguous_data.shape) == expected_data_shape + assert list(result.physical.shape) == expected_physical_shape assert result.v_to_ps == expected_v_to_ps assert torch.all(torch.eq(result.to_dense(), expected)) From b73de79144785fe294503bdfc6544342b0bd50ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 03:15:14 +0100 Subject: [PATCH 074/182] Improve print when falling back to dense --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 26f36be45..4a00a3ed1 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -124,12 +124,19 @@ def unwrap_to_dense(t: Tensor): else: return t - print(f"Falling back to dense for {func.__name__} called with the following args:") - for arg in args: - print(arg) - print("and the following kwargs:") - for k, v in kwargs.items(): - print(f"{k}: {v}") + print(f"Falling back to dense for {func.__name__}") + if len(args) > 0: + print("* args:") + for arg in args: + if isinstance(arg, Tensor): + 
print(f" > {arg.__class__.__name__} - {arg.shape}") + else: + print(f" > {arg}") + if len(kwargs) > 0: + print("* kwargs:") + for k, v in kwargs.items(): + print(f" > {k}: {v}") + print() return func(*tree_map(unwrap_to_dense, args), **tree_map(unwrap_to_dense, kwargs)) def __repr__(self): From e393096b020d131a1192e077a6c2ad8938da3cac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 03:37:16 +0100 Subject: [PATCH 075/182] Rename first_sort to encode_by_order and improve its docstring --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 19 +++++++++++-------- .../sparse/test_diagonal_sparse_tensor.py | 6 +++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 4a00a3ed1..880fc1d2d 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -167,12 +167,15 @@ def decorator(func): return decorator -def first_sort(input: list[int]) -> tuple[list[int], list[int]]: +def encode_by_order(input: list[int]) -> tuple[list[int], list[int]]: """ - Sorts a list of ints so that the first element to appear for the first time is 0, the second is - 1, etc. Elements may appear anywhere after their first appearance. Returns the sorted list and - list corresponding to the destination of each original int. destination[i] = j means that - all elements of value i in input are mapping to j in sorted list. + Encodes values based on the order of their first appearance, starting at 0 and incrementing. + + Returns the encoded list and the destination mapping each original int to its new encoding. + destination[i] = j means that all elements of value i in input are mapped to j in the encoded + list. + + The input list should only contain consecutive integers starting at 0. 
Examples: [1, 0, 3, 2] => [0, 1, 2, 3], [1, 0, 3, 2] @@ -193,9 +196,9 @@ def first_sort(input: list[int]) -> tuple[list[int], list[int]]: return output, destination -def first_sort_v_to_ps(v_to_ps: list[list[int]]) -> tuple[list[list[int]], list[int]]: +def encode_v_to_ps(v_to_ps: list[list[int]]) -> tuple[list[list[int]], list[int]]: flat_v_to_ps, spec = tree_flatten(v_to_ps) - sorted_flat_v_to_ps, destination = first_sort(flat_v_to_ps) + sorted_flat_v_to_ps, destination = encode_by_order(flat_v_to_ps) return tree_unflatten(sorted_flat_v_to_ps, spec), destination @@ -324,7 +327,7 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseT new_v_to_ps[dim] = [t.physical.ndim + n_added_physical_dims] n_added_physical_dims += 1 - new_v_to_ps, destination = first_sort_v_to_ps(new_v_to_ps) + new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps) new_physical = new_physical.movedim(list(range(len(destination))), destination) return DiagonalSparseTensor(new_physical, new_v_to_ps) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index c61928fd9..8c2ae0fda 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -9,7 +9,7 @@ _POINTWISE_FUNCTIONS, DiagonalSparseTensor, einsum, - first_sort, + encode_by_order, ) @@ -193,12 +193,12 @@ def test_view2( ([1, 0, 0, 1], [0, 1, 1, 0], [1, 0]), ], ) -def test_first_sort( +def test_encode_by_order( input: list[int], expected_output: list[int], expected_destination: list[int], ): - output, destination = first_sort(input) + output, destination = encode_by_order(input) assert output == expected_output assert destination == expected_destination From 7a8b8a40fc73cf969f1d4326c7a8ffa934d50d68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 03:42:16 +0100 Subject: [PATCH 076/182] Improve repr of DST * It's not supposed to contain extra info that are not argument of its init, like shape * We can use debug_info to get the shape. 
* Make its signature match that of Tensor --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 880fc1d2d..e29fb215e 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -139,11 +139,8 @@ def unwrap_to_dense(t: Tensor): print() return func(*tree_map(unwrap_to_dense, args), **tree_map(unwrap_to_dense, kwargs)) - def __repr__(self): - return ( - f"DiagonalSparseTensor(physical={self.physical}, v_to_ps_map={self.v_to_ps}, shape=" - f"{self.shape})" - ) + def __repr__(self, *, tensor_contents=None) -> str: + return f"DiagonalSparseTensor(physical={self.physical}, v_to_ps={self.v_to_ps})" def debug_info(self) -> str: info = ( From 28c72ed98cc287fd33b2d4c30b2a327a5e9c5c2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 03:48:58 +0100 Subject: [PATCH 077/182] Minor formatting fix --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index e29fb215e..993cb4bb9 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -384,7 +384,7 @@ def einsum( # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. # For this reason, an index is decomposed into sub-indices that are then independently # clustered. - # So if an index i in args for some DiagonalSparseTensor corresponds to a v_to_ps [j, k, l], then + # So if an index i in args for some DiagonalSparseTensor corresponds to a v_to_ps [j, k, l], # We will consider three indices (i, 0), (i, 1) and (i, 2). # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in From 1c1885c33344f59a86ed4afb0c54d33f2a903497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 04:57:31 +0100 Subject: [PATCH 078/182] Move _strides to where it is used: - It's only used in to_dense, and we should only call to_dense maximum one time, so it's a bit of a waste to compute the _strides all the time an object is instantiated. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 993cb4bb9..9560325a9 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -58,6 +58,15 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): self.physical = physical self.v_to_ps = v_to_ps + def to_dense( + self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None + ) -> Tensor: + assert dtype is None # We may add support for this later + assert masked_grad is None # We may add support for this later + + if self.physical.ndim == 0: + return self.physical + # This is a list of strides whose shape matches that of v_to_ps except that each element # is the stride factor of the index to get the right element for the corresponding virtual # dimension. 
Stride is the jump necessary to go from one element to the next one in the @@ -66,24 +75,16 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): # and so, if we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2] # which isa unique decomposition, then this corresponds to indexing dimensions v_to_ps[i] at # indices [j_0, j_1, j_2] - s = physical.shape - self._strides = [ + s = self.physical.shape + strides = [ list(accumulate([1] + [s[dim] for dim in dims[:0:-1]], operator.mul))[::-1] - for dims in v_to_ps + for dims in self.v_to_ps ] - def to_dense( - self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None - ) -> Tensor: - assert dtype is None # We may add support for this later - assert masked_grad is None # We may add support for this later - - if self.physical.ndim == 0: - return self.physical p_index_ranges = [torch.arange(s, device=self.physical.device) for s in self.physical.shape] p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") v_indices_grid = list[Tensor]() - for stride, dims in zip(self._strides, self.v_to_ps): + for stride, dims in zip(strides, self.v_to_ps): stride_ = torch.tensor(stride, device=self.physical.device, dtype=torch.int) v_indices_grid.append( torch.sum(torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1) From 09f3efc1f4fb3656f31dfffafdd534b21f0c0c4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 04:58:22 +0100 Subject: [PATCH 079/182] Minor reformating --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 9560325a9..08c4420b6 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -35,7 +35,8 @@ def __new__(cls, physical: Tensor, v_to_ps: list[list[int]]): def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): """ - This constructor is made for specifying physical and v_to_ps exactly. It should not modify it. + This constructor is made for specifying physical and v_to_ps exactly. It should not modify + it. For this reason, another constructor will be made to either modify the physical / v_to_ps to simplify the result, or to create a dense tensor directly if it's already dense. It could From 5703fd22a35eed66e8cfb3fbf32a0d8888cbdd6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 05:06:28 +0100 Subject: [PATCH 080/182] Add check that v_to_ps are correctly encoded * Also update test_einsum to fix encoding --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 5 +++++ tests/unit/sparse/test_diagonal_sparse_tensor.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 08c4420b6..194057bc9 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -56,6 +56,11 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): if len(set().union(*[set(dims) for dims in v_to_ps])) != physical.ndim: raise ValueError("Every dimension in physical must appear at least once in v_to_ps.") + if v_to_ps != encode_v_to_ps(v_to_ps)[0]: + raise ValueError( + f"v_to_ps elements are not encoded by first appearance. Found {v_to_ps}." 
+ ) + self.physical = physical self.v_to_ps = v_to_ps diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 8c2ae0fda..47eb87edb 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -27,7 +27,7 @@ def test_to_dense(): def test_einsum(): a = DiagonalSparseTensor(torch.randn([4, 5]), [[0], [0], [1]]) - b = DiagonalSparseTensor(torch.randn([5, 4]), [[1], [0], [0]]) + b = DiagonalSparseTensor(torch.randn([4, 5]), [[0], [1], [1]]) res = einsum((a, [0, 1, 2]), (b, [0, 2, 3]), output=[0, 1, 3]) From 8a13d498132331e595661f823c77f6c497bc2f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 05:09:20 +0100 Subject: [PATCH 081/182] Fix to_diagonal_sparse_tensor to not create physical with dim of size 1. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 194057bc9..01394b8ad 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -210,7 +210,9 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: if isinstance(t, DiagonalSparseTensor): return t else: - return DiagonalSparseTensor(t, [[i] for i in range(t.ndim)]) + physical = t.squeeze() # Remove all dimensions of size 1 + v_to_ps = [[i] if t.shape[i] != 1 else [] for i in range(t.ndim)] + return DiagonalSparseTensor(physical, v_to_ps) @DiagonalSparseTensor.implements(aten.mean.default) From 82d4b3314d4b51e217f6c44338be4c594f514825 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 05:44:11 +0100 Subject: [PATCH 082/182] Add function to create a DST without having to care about dims of size 1 or incorrect encoding of v_to_ps. * Name is temporary, at least it's easy to refactor compared to diagonal_sparse_tensor which is also the name of the python file. 
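* Rough usage sketch (illustrative, not part of this diff): make_dst re-encodes
  v_to_ps by order of first appearance and squeezes away physical dims of size 1
  before calling the strict DiagonalSparseTensor constructor.

      import torch

      from torchjd.sparse import make_dst

      data = torch.ones(1, 3)
      # Dim 0 of data has size 1: it is squeezed away and v_to_ps becomes
      # [[], [0], [], [0]], while the virtual shape stays (1, 3, 1, 3).
      t = make_dst(data, [[0], [1], [0], [1]])
      assert tuple(t.shape) == (1, 3, 1, 3)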
--- src/torchjd/autogram/_engine.py | 7 ++-- src/torchjd/sparse/__init__.py | 2 +- src/torchjd/sparse/_diagonal_sparse_tensor.py | 33 +++++++++++++++++-- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 98156f354..89815713c 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,7 +4,7 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge -from torchjd.sparse import DiagonalSparseTensor +from torchjd.sparse import make_dst from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator @@ -175,9 +175,8 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - physical = torch.ones_like(output).squeeze() - v_to_ps = [[dim] if output.shape[dim] != 1 else [] for dim in output_dims * 2] - jac_output = DiagonalSparseTensor(physical, v_to_ps) + v_to_ps = [[dim] for dim in output_dims * 2] + jac_output = make_dst(torch.ones_like(output), v_to_ps) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py index b54263e55..8350adff9 100644 --- a/src/torchjd/sparse/__init__.py +++ b/src/torchjd/sparse/__init__.py @@ -1 +1 @@ -from ._diagonal_sparse_tensor import DiagonalSparseTensor +from ._diagonal_sparse_tensor import DiagonalSparseTensor, make_dst diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 01394b8ad..6fad3c300 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -210,9 +210,36 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: if isinstance(t, DiagonalSparseTensor): return t else: - physical = t.squeeze() # Remove all dimensions of size 1 - v_to_ps = [[i] if t.shape[i] != 1 else [] for i in range(t.ndim)] - return DiagonalSparseTensor(physical, v_to_ps) + return make_dst(t, [[i] for i in range(t.ndim)]) + + +def fix_dim_encoding(physical: Tensor, v_to_ps: list[list[int]]) -> tuple[Tensor, list[list[int]]]: + v_to_ps, destination = encode_v_to_ps(v_to_ps) + source = list(range(physical.ndim)) + physical = physical.movedim(source, destination) + + return physical, v_to_ps + + +def fix_dim_of_size_1(physical: Tensor, v_to_ps: list[list[int]]) -> tuple[Tensor, list[list[int]]]: + is_of_size_1 = [s == 1 for s in physical.shape] + + def new_encoding(d: int) -> int: + n_removed_dims_before_d = sum(is_of_size_1[:d]) + return d - n_removed_dims_before_d + + physical = physical.squeeze() + v_to_ps = [[new_encoding(d) for d in dims if not is_of_size_1[d]] for dims in v_to_ps] + + return physical, v_to_ps + + +def make_dst(physical: Tensor, v_to_ps: list[list[int]]) -> DiagonalSparseTensor: + """Fix physical and v_to_ps and create a DiagonalSparseTensor with them.""" + + physical, v_to_ps = fix_dim_encoding(physical, v_to_ps) + physical, v_to_ps = fix_dim_of_size_1(physical, v_to_ps) + return DiagonalSparseTensor(physical, v_to_ps) @DiagonalSparseTensor.implements(aten.mean.default) From 9d21a6c42046bfc1e7708e1fe85736dbba5b6a6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 05:45:34 +0100 Subject: [PATCH 083/182] Remove test_view and rename test_view2 to test_view --- .../sparse/test_diagonal_sparse_tensor.py | 31 +------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git 
a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 47eb87edb..a5f773167 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -108,35 +108,6 @@ def test_unary(func): assert_close(res.to_dense(), func(c)) -@mark.parametrize( - ["physical_shape", "v_to_ps", "target_shape"], - [ - ([2, 3], [[0], [0], [1]], [2, 2, 3]), # no change of shape - ([2, 3], [[0], [0, 1]], [2, 6]), # no change of shape - ([2, 3], [[0], [0], [1]], [2, 6]), # squashing 2 dimensions - ([2, 3], [[0], [0, 1]], [2, 2, 3]), # unsquashing into 2 dimensions - ([2, 3], [[0, 0, 1]], [2, 6]), # unsquashing into 2 dimensions - ([2, 3], [[0], [0], [1]], [12]), # squashing 3 dimensions - ([2, 3], [[0, 0, 1]], [2, 2, 3]), # unsquashing into 3 dimensions - ( - [4], - [[0], [0]], - [2, 2, 4], - ), # unsquashing into 2 dimensions, need to split physical dimension - ([2, 3, 4], [[0], [0], [1], [2]], [4, 12]), # world boss - ], -) -def test_view(physical_shape: list[int], v_to_ps: list[list[int]], target_shape: list[int]): - a = randn_(tuple(physical_shape)) - t = DiagonalSparseTensor(a, v_to_ps) - - result = aten.view.default(t, target_shape) - expected = t.to_dense().reshape(target_shape) - - assert isinstance(result, DiagonalSparseTensor) - assert torch.all(torch.eq(result.to_dense(), expected)) - - @mark.parametrize( ["physical_shape", "v_to_ps", "target_shape", "expected_physical_shape", "expected_v_to_ps"], [ @@ -164,7 +135,7 @@ def test_view(physical_shape: list[int], v_to_ps: list[list[int]], target_shape: ([2, 12], [[0, 0], [1]], [2, 2, 3, 4], [2, 3, 4], [[0], [0], [1], [2]]), # world boss ], ) -def test_view2( +def test_view( physical_shape: list[int], v_to_ps: list[list[int]], target_shape: list[int], From 42e822851cf0eb9cd36571a9000c2e58c7184042 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 06:27:11 +0100 Subject: [PATCH 084/182] Minor reformatting --- .../sparse/test_diagonal_sparse_tensor.py | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index a5f773167..535416a6e 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -113,24 +113,12 @@ def test_unary(func): [ ([2, 3], [[0], [0], [1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # no change of shape ([2, 3], [[0], [0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # no change of shape - ([2, 3], [[0], [0], [1]], [2, 6], [2, 3], [[0], [0, 1]]), # squashing 2 dimensions - ( - [2, 3], - [[0], [0, 1]], - [2, 2, 3], - [2, 3], - [[0], [0], [1]], - ), # unsquashing into 2 dimensions - ([2, 3], [[0, 0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # unsquashing into 2 dimensions - ([2, 3], [[0], [0], [1]], [12], [2, 3], [[0, 0, 1]]), # squashing 3 dimensions - ([2, 3], [[0, 0, 1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # unsquashing into 3 dimensions - ( - [4], - [[0], [0]], - [2, 2, 4], - [2, 2], - [[0], [1], [0, 1]], - ), # unsquashing into 2 dimensions, need to split physical dimension + ([2, 3], [[0], [0], [1]], [2, 6], [2, 3], [[0], [0, 1]]), # squashing 2 dims + ([2, 3], [[0], [0, 1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # unsquashing into 2 dims + ([2, 3], [[0, 0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # unsquashing into 2 dims + ([2, 3], [[0], [0], [1]], [12], [2, 3], [[0, 0, 1]]), # squashing 3 dims + ([2, 3], [[0, 0, 1]], [2, 2, 3], [2, 
3], [[0], [0], [1]]), # unsquashing into 3 dims + ([4], [[0], [0]], [2, 2, 4], [2, 2], [[0], [1], [0, 1]]), # unsquashing physical dim ([2, 3, 4], [[0], [0], [1], [2]], [4, 12], [2, 12], [[0, 0], [1]]), # world boss ([2, 12], [[0, 0], [1]], [2, 2, 3, 4], [2, 3, 4], [[0], [0], [1], [2]]), # world boss ], From 33194c04ff5187e0ba897ea072e37ea75815944e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 07:15:02 +0100 Subject: [PATCH 085/182] Move p_to_vs outside of DiagonalSparseTensor and rename it p_to_vs_from_v_to_ps --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 6fad3c300..268e44a8c 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -103,20 +103,6 @@ def to_dense( res[tuple(v_indices_grid)] = self.physical return res - def p_to_vs(self) -> list[list[tuple[int, int]]]: - """ - A physical dimension is mapped to a list of couples of the form - (virtual_dim, sub_index_in_virtual_dim) - """ - res = dict[int, list[tuple[int, int]]]() - for v_dim, p_dims in enumerate(self.v_to_ps): - for i, p_dim in enumerate(p_dims): - if p_dim not in res: - res[p_dim] = [(v_dim, i)] - else: - res[p_dim].append((v_dim, i)) - return [res[i] for i in range(len(res))] - @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs @@ -171,6 +157,22 @@ def decorator(func): return decorator +def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]]]: + """ + A physical dimension is mapped to a list of couples of the form + (virtual_dim, sub_index_in_virtual_dim) + """ + + res = dict[int, list[tuple[int, int]]]() + for v_dim, p_dims in enumerate(v_to_ps): + for i, p_dim in enumerate(p_dims): + if p_dim not in res: + res[p_dim] = [(v_dim, i)] + else: + res[p_dim].append((v_dim, i)) + return [res[i] for i in range(len(res))] + + def encode_by_order(input: list[int]) -> tuple[list[int], list[int]]: """ Encodes values based on the order of their first appearance, starting at 0 and incrementing. 
@@ -459,7 +461,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None: assert indices_to_n_pdims[index] == len(ps) else: indices_to_n_pdims[index] = len(ps) - p_to_vs = t.p_to_vs() + p_to_vs = p_to_vs_from_v_to_ps(t.v_to_ps) for indices_ in p_to_vs: # elements in indices[indices_] map to the same dimension, they should be clustered # together From 21ca8e3d9f3bc391c797bbed614ee20f4a8303e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 09:13:21 +0100 Subject: [PATCH 086/182] Add get_groupings, fix_ungrouped_dims, and use it in constructor --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 46 +++++++++++++++++++ .../sparse/test_diagonal_sparse_tensor.py | 34 ++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 268e44a8c..d22703553 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -173,6 +173,40 @@ def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]] return [res[i] for i in range(len(res))] +def get_groupings(v_to_ps: list[list[int]]) -> list[list[int]]: + """Example: [[0, 1, 2], [2, 0, 1], [2]] => [[0, 1], [2]]""" + + mapping = dict[int, list[int]]() + + for p_dims in v_to_ps: + for i, p_dim in enumerate(p_dims): + if p_dim not in mapping: + mapping[p_dim] = p_dims[i:] + else: + mapping[p_dim] = longest_common_prefix(mapping[p_dim], p_dims[i:]) + + groups = [] + visited_is = set() + for i, group in mapping.items(): + if i in visited_is: + continue + + groups.append(group) + visited_is.update(set(group)) + + return groups + + +def longest_common_prefix(l1: list[int], l2: list[int]) -> list[int]: + prefix = [] + for a, b in zip(l1, l2, strict=False): + if a == b: + prefix.append(a) + else: + break + return prefix + + def encode_by_order(input: list[int]) -> tuple[list[int], list[int]]: """ Encodes values based on the order of their first appearance, starting at 0 and incrementing. 
@@ -236,11 +270,23 @@ def new_encoding(d: int) -> int: return physical, v_to_ps +def fix_ungrouped_dims( + physical: Tensor, v_to_ps: list[list[int]] +) -> tuple[Tensor, list[list[int]]]: + groups = get_groupings(v_to_ps) + physical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups]) + mapping = {group[0]: i for i, group in enumerate(groups)} + new_v_to_ps = [[mapping[i] for i in dims if i in mapping] for dims in v_to_ps] + + return physical, new_v_to_ps + + def make_dst(physical: Tensor, v_to_ps: list[list[int]]) -> DiagonalSparseTensor: """Fix physical and v_to_ps and create a DiagonalSparseTensor with them.""" physical, v_to_ps = fix_dim_encoding(physical, v_to_ps) physical, v_to_ps = fix_dim_of_size_1(physical, v_to_ps) + physical, v_to_ps = fix_ungrouped_dims(physical, v_to_ps) return DiagonalSparseTensor(physical, v_to_ps) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 535416a6e..5f992f0d2 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -10,6 +10,8 @@ DiagonalSparseTensor, einsum, encode_by_order, + fix_ungrouped_dims, + get_groupings, ) @@ -161,3 +163,35 @@ def test_encode_by_order( assert output == expected_output assert destination == expected_destination + + +@mark.parametrize( + ["v_to_ps", "expected_groupings"], + [ + ([[0, 1, 2], [2, 0, 1], [2]], [[0, 1], [2]]), + ], +) +def test_get_groupings(v_to_ps: list[list[int]], expected_groupings: list[list[int]]): + groupings = get_groupings(v_to_ps) + print(groupings) + + assert groupings == expected_groupings + + +@mark.parametrize( + ["physical_shape", "v_to_ps", "expected_physical_shape", "expected_v_to_ps"], + [ + ([3, 4, 5], [[0, 1, 2], [2, 0, 1], [2]], [12, 5], [[0, 1], [1, 0], [1]]), + ], +) +def test_fix_ungrouped_dims( + physical_shape: list[int], + v_to_ps: list[list[int]], + expected_physical_shape: list[int], + expected_v_to_ps: list[list[int]], +): + physical = torch.randn(physical_shape) + fixed_physical, fixed_v_to_ps = fix_ungrouped_dims(physical, v_to_ps) + + assert list(fixed_physical.shape) == expected_physical_shape + assert fixed_v_to_ps == expected_v_to_ps From 15a615f89deb4c79229ceadc78a2cb42816d878f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 28 Oct 2025 09:16:24 +0100 Subject: [PATCH 087/182] Add check of maximal grouping in DST.__init__ * Note that this check sometimes fails. Need to fix the cause. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index d22703553..0a22d4dec 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -61,6 +61,9 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): f"v_to_ps elements are not encoded by first appearance. Found {v_to_ps}." ) + if any(len(group) != 1 for group in get_groupings(v_to_ps)): + raise ValueError(f"Dimensions must be maximally grouped. 
Found {v_to_ps}.") + self.physical = physical self.v_to_ps = v_to_ps From ee8e1815a9458453920248df2452531155a086b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 00:31:00 +0100 Subject: [PATCH 088/182] Use make_dst in einsum --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 0a22d4dec..71ec64c7d 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -548,7 +548,9 @@ def unique_int(pair: tuple[int, int]) -> int: v_to_ps.append(current_v_to_ps) physical = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output) - return DiagonalSparseTensor(physical, v_to_ps) + # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped. + # Maybe there is a way to fix that though. + return make_dst(physical, v_to_ps) @DiagonalSparseTensor.implements(aten.bmm.default) From 8992d3f5eccef158ab080af0119371b79e7e4c61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 00:44:39 +0100 Subject: [PATCH 089/182] Add possibility to slice dimension of size 1. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 71ec64c7d..c322bb6bd 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -433,12 +433,30 @@ def slice_Tensor( physical_dims = t.v_to_ps[dim] - if len(physical_dims) != 1: - raise ValueError("Cannot yet slice virtual dim corresponding to several physical dims.") - - physical_dim = physical_dims[0] - - new_physical = aten.slice.Tensor(t.physical, physical_dim, start, end, step) + if len(physical_dims) > 1: + raise ValueError( + "Cannot yet slice virtual dim corresponding to several physical dims.\n" + f"{t.debug_info()}\n" + f"dim={dim}, start={start}, end={end}, step={step}." + ) + elif len(physical_dims) == 0: + # Trying to slice a virtual dim of size 1. + # Either + # - the element of this dim is included in the slice: keep it as it is + # - it's not included in the slice (e.g. end<=start): we would end up with a size of 0 on + # that dimension, so we'd need to add a dimension of size 0 to the physical. This is not + # implemented yet. + start_ = start if start is not None else 0 + end_ = end if end is not None else 1 + if end_ <= start_: # TODO: the condition might be a bit more complex if step != 1 + raise NotImplementedError( + "Slicing of dimension of size 1 leading to dimension of size 0 not implemented yet." 
+ ) + else: + new_physical = t.physical + else: + physical_dim = physical_dims[0] + new_physical = aten.slice.Tensor(t.physical, physical_dim, start, end, step) return DiagonalSparseTensor(new_physical, t.v_to_ps) From cf4f950e413b3297fc9e553fa827e882845eb1a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 00:57:37 +0100 Subject: [PATCH 090/182] Fix to_dense when a pdims is [] --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index c322bb6bd..e6133f680 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -95,9 +95,15 @@ def to_dense( v_indices_grid = list[Tensor]() for stride, dims in zip(strides, self.v_to_ps): stride_ = torch.tensor(stride, device=self.physical.device, dtype=torch.int) - v_indices_grid.append( - torch.sum(torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1) - ) + + if len(dims) > 0: + v_indices_grid.append( + torch.sum( + torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1 + ) + ) + else: + v_indices_grid.append(torch.tensor(0, device=self.physical.device, dtype=torch.int)) # This is supposed to be a vector of shape d_1 * d_2 ... # whose elements are the coordinates 1 in p_indices_grad[d_1] times stride 1 # plus coordinates 2 in p_indices_grad[d_2] times stride 2, etc... From ebffc3c582321ecdfe1c9fd4935b503a850c07ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 01:01:10 +0100 Subject: [PATCH 091/182] Stop creating index tensors on data device. I think it's ok / faster like that. Add todo to benchmark this later --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index e6133f680..1a01b531b 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -90,11 +90,13 @@ def to_dense( for dims in self.v_to_ps ] - p_index_ranges = [torch.arange(s, device=self.physical.device) for s in self.physical.shape] + # TODO: I think it's ok to create index tensors on CPU when tensor to index is on cuda. Idk + # what's faster + p_index_ranges = [torch.arange(s) for s in self.physical.shape] p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") v_indices_grid = list[Tensor]() for stride, dims in zip(strides, self.v_to_ps): - stride_ = torch.tensor(stride, device=self.physical.device, dtype=torch.int) + stride_ = torch.tensor(stride, dtype=torch.int) if len(dims) > 0: v_indices_grid.append( @@ -103,7 +105,7 @@ def to_dense( ) ) else: - v_indices_grid.append(torch.tensor(0, device=self.physical.device, dtype=torch.int)) + v_indices_grid.append(torch.tensor(0, dtype=torch.int)) # This is supposed to be a vector of shape d_1 * d_2 ... # whose elements are the coordinates 1 in p_indices_grad[d_1] times stride 1 # plus coordinates 2 in p_indices_grad[d_2] times stride 2, etc... 
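A hedged illustration of the case fixed above (a virtual dimension whose p_dims list is empty, i.e. a dimension of size 1); the values are made up, and it assumes the plain constructor accepts an empty p_dims entry as the slicing code above does:

    import torch
    from torchjd.sparse._diagonal_sparse_tensor import DiagonalSparseTensor

    # A physical of shape [2] with v_to_ps [[0], [], [0]] represents a [2, 1, 2] tensor whose
    # middle dimension is backed by no physical dimension at all.
    t = DiagonalSparseTensor(torch.tensor([1.0, 2.0]), [[0], [], [0]])
    dense = t.to_dense()
    assert dense.shape == (2, 1, 2)
    # Only the "diagonal" entries dense[i, 0, i] are populated from the physical data.
    assert torch.equal(dense, torch.tensor([[[1.0, 0.0]], [[0.0, 2.0]]]))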
From fe84b80ab504abd5bef23942b8f2182d37edbe65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 01:36:17 +0100 Subject: [PATCH 092/182] Rename current_product to current_size --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 1a01b531b..8b7581cfa 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -369,17 +369,17 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTen p_shape = t.physical.shape for s in shape: group = [] - current_product = 1 + current_size = 1 - while current_product < s: + while current_size < s: if idx >= len(flat_v_to_ps): raise ValueError() group.append(flat_v_to_ps[idx]) - current_product *= p_shape[flat_v_to_ps[idx]] + current_size *= p_shape[flat_v_to_ps[idx]] idx += 1 - if current_product > s: + if current_size > s: raise ValueError() new_v_to_ps.append(group) From 92bcaa67052e838e0ab8765dd3f0f763f95ee9e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 01:37:18 +0100 Subject: [PATCH 093/182] Add unsquash_pdim --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 19 ++++++++ .../sparse/test_diagonal_sparse_tensor.py | 44 +++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 8b7581cfa..5b7eda915 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -353,6 +353,25 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor return DiagonalSparseTensor(t.physical, new_v_to_ps) +def unsquash_pdim( + physical: Tensor, v_to_ps: list[list[int]], pdim: int, new_pdim_shape: list[int] +) -> tuple[Tensor, list[list[int]]]: + new_shape = list(physical.shape) + new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :] + new_physical = physical.reshape(new_shape) + + def new_encodings(d: int) -> list[int]: + if d < pdim: + return [d] + elif d > pdim: + return [d + len(new_pdim_shape) - 1] + else: + return [pdim + i for i in range(len(new_pdim_shape))] + + new_v_to_ps = [[new_d for d in dims for new_d in new_encodings(d)] for dims in v_to_ps] + return new_physical, new_v_to_ps + + @DiagonalSparseTensor.implements(aten.view.default) def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTensor: # TODO: add error message when error is raised diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 5f992f0d2..c000dde20 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -12,6 +12,7 @@ encode_by_order, fix_ungrouped_dims, get_groupings, + unsquash_pdim, ) @@ -195,3 +196,46 @@ def test_fix_ungrouped_dims( assert list(fixed_physical.shape) == expected_physical_shape assert fixed_v_to_ps == expected_v_to_ps + + +@mark.parametrize( + [ + "physical_shape", + "v_to_ps", + "pdim", + "new_pdim_shape", + "expected_physical_shape", + "expected_v_to_ps", + ], + [ + ([4], [[0], [0]], 0, [4], [4], [[0], [0]]), # trivial + ([4], [[0], [0]], 0, [2, 2], [2, 2], [[0, 1], [0, 1]]), + ( + [3, 4, 5], + [[0, 1, 2], [1], [2, 1, 0], [0, 0, 1, 1, 2, 2], [2, 2, 1, 1, 0, 0]], + 1, + [2, 1, 1, 2], + [3, 2, 1, 1, 2, 5], + [ + [0, 
1, 2, 3, 4, 5], + [1, 2, 3, 4], + [5, 1, 2, 3, 4, 0], + [0, 0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 5], + [5, 5, 1, 2, 3, 4, 1, 2, 3, 4, 0, 0], + ], + ), + ], +) +def test_unsquash_pdim( + physical_shape: list[int], + v_to_ps: list[list[int]], + pdim: int, + new_pdim_shape: list[int], + expected_physical_shape: list[int], + expected_v_to_ps: list[list[int]], +): + physical = torch.randn(physical_shape) + new_physical, new_v_to_ps = unsquash_pdim(physical, v_to_ps, pdim, new_pdim_shape) + + assert list(new_physical.shape) == expected_physical_shape + assert new_v_to_ps == expected_v_to_ps From ceeeea677ea28b5ea9fbeea165b182dc4c0f7401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 03:24:48 +0100 Subject: [PATCH 094/182] Revamp unsquash_dim: * Make it not take v_to_ps as input * Make it return a new encoding as output, mapping each original dim to a list of dimensions. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 8 ++--- .../sparse/test_diagonal_sparse_tensor.py | 29 +++++-------------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 5b7eda915..1c2fe8973 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -354,13 +354,13 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor def unsquash_pdim( - physical: Tensor, v_to_ps: list[list[int]], pdim: int, new_pdim_shape: list[int] + physical: Tensor, pdim: int, new_pdim_shape: list[int] ) -> tuple[Tensor, list[list[int]]]: new_shape = list(physical.shape) new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :] new_physical = physical.reshape(new_shape) - def new_encodings(d: int) -> list[int]: + def new_encoding_fn(d: int) -> list[int]: if d < pdim: return [d] elif d > pdim: @@ -368,8 +368,8 @@ def new_encodings(d: int) -> list[int]: else: return [pdim + i for i in range(len(new_pdim_shape))] - new_v_to_ps = [[new_d for d in dims for new_d in new_encodings(d)] for dims in v_to_ps] - return new_physical, new_v_to_ps + new_encoding = [new_encoding_fn(d) for d in range(len(physical.shape))] + return new_physical, new_encoding @DiagonalSparseTensor.implements(aten.view.default) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index c000dde20..5f8a2dbac 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -201,41 +201,26 @@ def test_fix_ungrouped_dims( @mark.parametrize( [ "physical_shape", - "v_to_ps", "pdim", "new_pdim_shape", "expected_physical_shape", - "expected_v_to_ps", + "expected_new_encoding", ], [ - ([4], [[0], [0]], 0, [4], [4], [[0], [0]]), # trivial - ([4], [[0], [0]], 0, [2, 2], [2, 2], [[0, 1], [0, 1]]), - ( - [3, 4, 5], - [[0, 1, 2], [1], [2, 1, 0], [0, 0, 1, 1, 2, 2], [2, 2, 1, 1, 0, 0]], - 1, - [2, 1, 1, 2], - [3, 2, 1, 1, 2, 5], - [ - [0, 1, 2, 3, 4, 5], - [1, 2, 3, 4], - [5, 1, 2, 3, 4, 0], - [0, 0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 5], - [5, 5, 1, 2, 3, 4, 1, 2, 3, 4, 0, 0], - ], - ), + ([4], 0, [4], [4], [[0]]), # trivial + ([4], 0, [2, 2], [2, 2], [[0, 1]]), + ([3, 4, 5], 1, [2, 1, 1, 2], [3, 2, 1, 1, 2, 5], [[0], [1, 2, 3, 4], [5]]), ], ) def test_unsquash_pdim( physical_shape: list[int], - v_to_ps: list[list[int]], pdim: int, new_pdim_shape: list[int], expected_physical_shape: list[int], - expected_v_to_ps: list[list[int]], + expected_new_encoding: 
list[list[int]], ): physical = torch.randn(physical_shape) - new_physical, new_v_to_ps = unsquash_pdim(physical, v_to_ps, pdim, new_pdim_shape) + new_physical, new_encoding = unsquash_pdim(physical, pdim, new_pdim_shape) assert list(new_physical.shape) == expected_physical_shape - assert new_v_to_ps == expected_v_to_ps + assert new_encoding == expected_new_encoding From 6ab6b4cbd9ce5cb74cfddcdfc43b7721821d05ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 03:29:53 +0100 Subject: [PATCH 095/182] Add possibility to unsquash physical dimensions in view. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 42 +++++++++++++++---- .../sparse/test_diagonal_sparse_tensor.py | 1 + 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 1c2fe8973..1d8540399 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -385,28 +385,56 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTen new_v_to_ps = [] idx = 0 flat_v_to_ps = [dim for dims in t.v_to_ps for dim in dims] - p_shape = t.physical.shape + new_physical = t.physical for s in shape: group = [] current_size = 1 while current_size < s: if idx >= len(flat_v_to_ps): + # TODO: I don't think this can happen, need to review and remove if I'm right. raise ValueError() - group.append(flat_v_to_ps[idx]) - current_size *= p_shape[flat_v_to_ps[idx]] - idx += 1 + pdim = flat_v_to_ps[idx] + pdim_size = new_physical.shape[pdim] - if current_size > s: - raise ValueError() + if current_size * pdim_size > s: + # Need to split physical dimension + if s % current_size != 0: + raise ValueError("Can't split physical dimension") + + new_pdim_first_dim_size = s // current_size + + if pdim_size % new_pdim_first_dim_size != 0: + raise ValueError("Can't split physical dimension") + + new_pdim_shape = [new_pdim_first_dim_size, pdim_size // new_pdim_first_dim_size] + new_physical, new_encoding = unsquash_pdim(new_physical, pdim, new_pdim_shape) + + new_v_to_ps = [ + [new_d for d in dims for new_d in new_encoding[d]] for dims in new_v_to_ps + ] + # A bit of a weird trick here. We want to re-encode flat_v_to_ps according to + # new_encoding. However, re-encoding elements before idx would potentially change + # the length of the list before idx, so idx would not have the right value anymore. + # Since we don't need the elements of flat_v_to_ps that are before idx anyway, we + # just get rid of them and re-encode flat_v_to_ps[idx:] instead, and reset idx to 0 + # to say that we're back at the beginning of this new list. + flat_v_to_ps = [new_d for d in flat_v_to_ps[idx:] for new_d in new_encoding[d]] + idx = 0 + + group.append(pdim) + current_size *= new_physical.shape[pdim] + idx += 1 new_v_to_ps.append(group) if idx != len(flat_v_to_ps): raise ValueError(f"idx != len(flat_v_to_ps). {idx}; {flat_v_to_ps}; {shape}; {t.v_to_ps}") - return DiagonalSparseTensor(t.physical, new_v_to_ps) + # The above code does not handle physical dimension squashing, so the physical is not + # necessarily maximally squashed at this point, so we need the safe constructor. 
+ return make_dst(new_physical, new_v_to_ps) @DiagonalSparseTensor.implements(aten.expand.default) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 5f8a2dbac..a2c73f08b 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -122,6 +122,7 @@ def test_unary(func): ([2, 3], [[0], [0], [1]], [12], [2, 3], [[0, 0, 1]]), # squashing 3 dims ([2, 3], [[0, 0, 1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # unsquashing into 3 dims ([4], [[0], [0]], [2, 2, 4], [2, 2], [[0], [1], [0, 1]]), # unsquashing physical dim + ([4], [[0], [0]], [4, 2, 2], [2, 2], [[0, 1], [0], [1]]), # unsquashing physical dim ([2, 3, 4], [[0], [0], [1], [2]], [4, 12], [2, 12], [[0, 0], [1]]), # world boss ([2, 12], [[0, 0], [1]], [2, 2, 3, 4], [2, 3, 4], [[0], [0], [1], [2]]), # world boss ], From b746af8fc726c792a40d272de51ff03b5e46a9b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 03:48:51 +0100 Subject: [PATCH 096/182] Fix get_groupings and add test that failed before this fix --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 3 ++- tests/unit/sparse/test_diagonal_sparse_tensor.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 1d8540399..2d3e1a842 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -202,7 +202,8 @@ def get_groupings(v_to_ps: list[list[int]]) -> list[list[int]]: if i in visited_is: continue - groups.append(group) + available_dims = set(group) - visited_is + groups.append(list(available_dims)) visited_is.update(set(group)) return groups diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index a2c73f08b..aff989a53 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -184,6 +184,7 @@ def test_get_groupings(v_to_ps: list[list[int]], expected_groupings: list[list[i ["physical_shape", "v_to_ps", "expected_physical_shape", "expected_v_to_ps"], [ ([3, 4, 5], [[0, 1, 2], [2, 0, 1], [2]], [12, 5], [[0, 1], [1, 0], [1]]), + ([32, 20, 8], [[0], [1, 0], [2]], [32, 20, 8], [[0], [1, 0], [2]]), ], ) def test_fix_ungrouped_dims( From 59d1f2c3fdc177fb94924c1105b4e136db0a1361 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 03:55:39 +0100 Subject: [PATCH 097/182] Fix dim encoding in transpose_int --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 2d3e1a842..82446dac7 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -534,7 +534,8 @@ def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSpar new_v_to_ps[dim0] = t.v_to_ps[dim1] new_v_to_ps[dim1] = t.v_to_ps[dim0] - return DiagonalSparseTensor(t.physical, new_v_to_ps) + new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) + return DiagonalSparseTensor(new_physical, new_v_to_ps) def einsum( From 4c48523352af7990f68b8194be147d40b39aec2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 04:00:31 +0100 Subject: [PATCH 098/182] Name more variables in __torch_dispatch__ --- 
src/torchjd/sparse/_diagonal_sparse_tensor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 82446dac7..3df1026af 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -141,7 +141,10 @@ def unwrap_to_dense(t: Tensor): for k, v in kwargs.items(): print(f" > {k}: {v}") print() - return func(*tree_map(unwrap_to_dense, args), **tree_map(unwrap_to_dense, kwargs)) + + unwrapped_args = tree_map(unwrap_to_dense, args) + unwrapped_kwargs = tree_map(unwrap_to_dense, kwargs) + return func(*unwrapped_args, **unwrapped_kwargs) def __repr__(self, *, tensor_contents=None) -> str: return f"DiagonalSparseTensor(physical={self.physical}, v_to_ps={self.v_to_ps})" From 123ac3c2e733ba276f0fad9c58bc6f043a41da08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 04:07:55 +0100 Subject: [PATCH 099/182] Remove todos in view_default --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 3df1026af..0db69a968 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -378,9 +378,6 @@ def new_encoding_fn(d: int) -> list[int]: @DiagonalSparseTensor.implements(aten.view.default) def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTensor: - # TODO: add error message when error is raised - # TODO: handle case where the physical has to be reshaped - assert isinstance(t, DiagonalSparseTensor) if prod(shape) != t.numel(): From 469ddb3db717d07fc847b121dce39c6deef34042 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 04:23:41 +0100 Subject: [PATCH 100/182] Add support for shape inference in view_default * This handles the case where an element in shape is -1 --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 0db69a968..792d29a64 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -376,10 +376,23 @@ def new_encoding_fn(d: int) -> list[int]: return new_physical, new_encoding +def infer_shape(shape: list[int], numel: int) -> list[int]: + if shape.count(-1) > 1: + raise ValueError("Only one dimension can be inferred") + known = 1 + for s in shape: + if s != -1: + known *= s + inferred = numel // known + return [inferred if s == -1 else s for s in shape] + + @DiagonalSparseTensor.implements(aten.view.default) def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) + shape = infer_shape(shape, t.numel()) + if prod(shape) != t.numel(): raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") From bd3e569280f51574b89f8bd81286dd337b543ce0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 04:24:09 +0100 Subject: [PATCH 101/182] Add _unsafe_view_default --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 792d29a64..524ef5820 100644 --- 
a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -451,6 +451,13 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTen return make_dst(new_physical, new_v_to_ps) +@DiagonalSparseTensor.implements(aten._unsafe_view.default) +def _unsafe_view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTensor: + return view_default( + t, shape + ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp + + @DiagonalSparseTensor.implements(aten.expand.default) def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseTensor: # note that sizes could also be just an int, or a torch.Size i think From 42462bcf465a233a3610c2b1b87a1ded46a1af19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 04:32:42 +0100 Subject: [PATCH 102/182] Add threshold_backward_default * Used in AlexNet, probably for ReLU backward. * It's pointwise so trivial implementation --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 524ef5820..0f902b18a 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -501,6 +501,15 @@ def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor: return DiagonalSparseTensor(new_physical, t.v_to_ps) +@DiagonalSparseTensor.implements(aten.threshold_backward.default) +def threshold_backward_default( + grad_output: DiagonalSparseTensor, self: Tensor, threshold +) -> DiagonalSparseTensor: + new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) + + return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) + + @DiagonalSparseTensor.implements(aten.slice.Tensor) def slice_Tensor( t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 From 19e049da6e353c30b923bfd51c8fba4eb3025d08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 05:29:22 +0100 Subject: [PATCH 103/182] Add sum_dim_IntList --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 0f902b18a..365f88ed4 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -317,6 +317,23 @@ def sum_default(t: DiagonalSparseTensor) -> Tensor: return aten.sum.default(t.physical) +@DiagonalSparseTensor.implements(aten.sum.dim_IntList) +def sum_dim_IntList(t: DiagonalSparseTensor, dim: list[int], keepdim: bool, dtype=None) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + if dtype: + raise NotImplementedError() + + all_dims = list(range(t.ndim)) + result = einsum((t, all_dims), output=[d for d in all_dims if d not in dim]) + + if keepdim: + for d in dim: + result = result.unsqueeze(d) + + return result + + @DiagonalSparseTensor.implements(aten.pow.Tensor_Scalar) def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) From 519d11b3f88d5f9d18d562973036d4d1f77e526d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 06:14:55 +0100 Subject: [PATCH 104/182] Add 
broadcast_tensors_default --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 365f88ed4..c40b231c1 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -510,6 +510,33 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseT return DiagonalSparseTensor(new_physical, new_v_to_ps) +@DiagonalSparseTensor.implements(aten.broadcast_tensors.default) +def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: + if len(tensors) != 2: + raise NotImplementedError() + + t1, t2 = tensors + + if t1.shape == t2.shape: + return t1, t2 + + a = t1 if t1.ndim >= t2.ndim else t2 + b = t2 if t1.ndim >= t2.ndim else t1 + + a_shape = list(a.shape) + padded_b_shape = [1] * (a.ndim - b.ndim) + list(b.shape) + + new_shape = list[int]() + + for s_a, s_b in zip(a_shape, padded_b_shape): + if s_a != 1 and s_b != 1 and s_a != s_b: + raise ValueError("Incompatible shapes for broadcasting") + else: + new_shape.append(max(s_a, s_b)) + + return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape) + + @DiagonalSparseTensor.implements(aten.div.Scalar) def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) From 64d0ae8f38d48900959e96b210d7625996918151 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 06:15:05 +0100 Subject: [PATCH 105/182] Fix mul_Tensor --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index c40b231c1..27a50965a 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -592,11 +592,15 @@ def slice_Tensor( @DiagonalSparseTensor.implements(aten.mul.Tensor) def mul_Tensor(t1: Tensor, t2: Tensor) -> DiagonalSparseTensor: - # Element-wise multiplication + # Element-wise multiplication with broadcasting assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) - new_physical = aten.mul.Tensor(t1, t2.physical) - return DiagonalSparseTensor(new_physical, t2.v_to_ps) + t1_, t2_ = aten.broadcast_tensors.default([t1, t2]) + t1_ = to_diagonal_sparse_tensor(t1_) + t2_ = to_diagonal_sparse_tensor(t2_) + + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) @DiagonalSparseTensor.implements(aten.transpose.int) From 47f32f8103899fa098d6b619250d860116ecc6f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 06:16:05 +0100 Subject: [PATCH 106/182] Add missing default value for keepdim in sum_dim_IntList --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 27a50965a..e95f8dcac 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -318,7 +318,9 @@ def sum_default(t: DiagonalSparseTensor) -> Tensor: @DiagonalSparseTensor.implements(aten.sum.dim_IntList) -def sum_dim_IntList(t: DiagonalSparseTensor, dim: list[int], keepdim: bool, dtype=None) -> Tensor: +def sum_dim_IntList( + t: 
DiagonalSparseTensor, dim: list[int], keepdim: bool = False, dtype=None +) -> Tensor: assert isinstance(t, DiagonalSparseTensor) if dtype: From b4eb02123188db900a39d4663f693b7744303bf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 06:49:29 +0100 Subject: [PATCH 107/182] Improve print when falling back to dense * Support tensor lists * Print pshape and v_to_ps for DSTs --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index e95f8dcac..84ac9c7a6 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -128,12 +128,22 @@ def unwrap_to_dense(t: Tensor): else: return t + def tensor_to_str(tensor: Tensor) -> str: + result = f"{tensor.__class__.__name__} - shape: {tensor.shape}" + if isinstance(tensor, DiagonalSparseTensor): + result += f" - pshape: {tensor.physical.shape} - v_to_ps: {tensor.v_to_ps}" + + return result + print(f"Falling back to dense for {func.__name__}") if len(args) > 0: print("* args:") for arg in args: if isinstance(arg, Tensor): - print(f" > {arg.__class__.__name__} - {arg.shape}") + print(f" > {tensor_to_str(arg)}") + elif isinstance(arg, list) and len(arg) > 0 and isinstance(arg[0], Tensor): + list_content = "\n ".join([tensor_to_str(t) for t in arg]) + print(f" > [{list_content}]") else: print(f" > {arg}") if len(kwargs) > 0: From 909bcb70a79d285c5cdcf0ac11acabad8bb4afe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 06:51:52 +0100 Subject: [PATCH 108/182] Add mul_Scalar --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 84ac9c7a6..027edd62e 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -615,6 +615,16 @@ def mul_Tensor(t1: Tensor, t2: Tensor) -> DiagonalSparseTensor: return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) +@DiagonalSparseTensor.implements(aten.mul.Scalar) +def mul_Scalar(t: DiagonalSparseTensor, scalar) -> DiagonalSparseTensor: + # TODO: maybe it could be that scalar is a scalar DST and t is a normal tensor. 
Need to check + # that + + assert isinstance(t, DiagonalSparseTensor) + new_physical = aten.mul.Scalar(t.physical, scalar) + return DiagonalSparseTensor(new_physical, t.v_to_ps) + + @DiagonalSparseTensor.implements(aten.transpose.int) def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) From c72f99e85dfca4a8925403883ef42f0eb0c9848f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 07:10:34 +0100 Subject: [PATCH 109/182] Fix mul_Tensor to be able to handle non-tensor input --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 027edd62e..c847c62b2 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -603,11 +603,21 @@ def slice_Tensor( @DiagonalSparseTensor.implements(aten.mul.Tensor) -def mul_Tensor(t1: Tensor, t2: Tensor) -> DiagonalSparseTensor: +def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> DiagonalSparseTensor: # Element-wise multiplication with broadcasting assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) - t1_, t2_ = aten.broadcast_tensors.default([t1, t2]) + if isinstance(t1, int) or isinstance(t1, float): + t1_ = torch.tensor(t1, device=t2.device) + else: + t1_ = t1 + + if isinstance(t2, int) or isinstance(t2, float): + t2_ = torch.tensor(t2, device=t1.device) + else: + t2_ = t2 + + t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) t1_ = to_diagonal_sparse_tensor(t1_) t2_ = to_diagonal_sparse_tensor(t2_) From 4b3593b4d4386b2e077a408c90616e65830587e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 07:24:30 +0100 Subject: [PATCH 110/182] Make test_einsum parametrizable --- .../sparse/test_diagonal_sparse_tensor.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index aff989a53..35f09a8fb 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -28,13 +28,25 @@ def test_to_dense(): assert c[i, j, j, i] == a[i, j] -def test_einsum(): - a = DiagonalSparseTensor(torch.randn([4, 5]), [[0], [0], [1]]) - b = DiagonalSparseTensor(torch.randn([4, 5]), [[0], [1], [1]]) +@mark.parametrize( + ["a_pshape", "a_v_to_ps", "b_pshape", "b_v_to_ps", "a_indices", "b_indices", "output_indices"], + [([4, 5], [[0], [0], [1]], [4, 5], [[0], [1], [1]], [0, 1, 2], [0, 2, 3], [0, 1, 3])], +) +def test_einsum( + a_pshape: list[int], + a_v_to_ps: list[list[int]], + b_pshape: list[int], + b_v_to_ps: list[list[int]], + a_indices: list[int], + b_indices: list[int], + output_indices: list[int], +): + a = DiagonalSparseTensor(torch.randn(a_pshape), a_v_to_ps) + b = DiagonalSparseTensor(torch.randn(b_pshape), b_v_to_ps) - res = einsum((a, [0, 1, 2]), (b, [0, 2, 3]), output=[0, 1, 3]) + res = einsum((a, a_indices), (b, b_indices), output=output_indices) - expected = torch.einsum("ijk,ikl->ijl", a.to_dense(), b.to_dense()) + expected = torch.einsum(a.to_dense(), a_indices, b.to_dense(), b_indices, output_indices) assert_close(res.to_dense(), expected) From 2052d81a2903823b286af0ad996722271875018a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 
2025 07:31:07 +0100 Subject: [PATCH 111/182] Add failing example to test_einsum * This is the same setting as what happens in WithMultiHeadAttention()-32 during matrix multiplication, but with reduced dimension sizes to make the test faster --- tests/unit/sparse/test_diagonal_sparse_tensor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 35f09a8fb..25553016b 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -30,7 +30,10 @@ def test_to_dense(): @mark.parametrize( ["a_pshape", "a_v_to_ps", "b_pshape", "b_v_to_ps", "a_indices", "b_indices", "output_indices"], - [([4, 5], [[0], [0], [1]], [4, 5], [[0], [1], [1]], [0, 1, 2], [0, 2, 3], [0, 1, 3])], + [ + ([4, 5], [[0], [0], [1]], [4, 5], [[0], [1], [1]], [0, 1, 2], [0, 2, 3], [0, 1, 3]), + ([2, 3, 5], [[0, 1], [2, 0]], [10, 3], [[0], [1]], [0, 1], [1, 2], [0, 2]), + ], ) def test_einsum( a_pshape: list[int], From a1ccad54ee634bb934b220a890066cecc504b668 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 07:33:20 +0100 Subject: [PATCH 112/182] Add even simpler failing example in test_einsum --- tests/unit/sparse/test_diagonal_sparse_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 25553016b..176913b75 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -33,6 +33,7 @@ def test_to_dense(): [ ([4, 5], [[0], [0], [1]], [4, 5], [[0], [1], [1]], [0, 1, 2], [0, 2, 3], [0, 1, 3]), ([2, 3, 5], [[0, 1], [2, 0]], [10, 3], [[0], [1]], [0, 1], [1, 2], [0, 2]), + ([2, 3], [[0, 1]], [6], [[0]], [0], [0], []), ], ) def test_einsum( From 2675d8e371e3af902b46d0c081697e8945f6587c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 29 Oct 2025 07:39:16 +0100 Subject: [PATCH 113/182] Add assertion about result being DST in test_einsum --- tests/unit/sparse/test_diagonal_sparse_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 176913b75..e8df395a2 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -51,6 +51,8 @@ def test_einsum( res = einsum((a, a_indices), (b, b_indices), output=output_indices) expected = torch.einsum(a.to_dense(), a_indices, b.to_dense(), b_indices, output_indices) + + assert isinstance(res, DiagonalSparseTensor) assert_close(res.to_dense(), expected) From 3121b6a5a968bd9601abd9d28007fa729b437eeb Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 29 Oct 2025 09:37:08 +0100 Subject: [PATCH 114/182] Add a test to get_groupings. 
May want to add the same to `test_fix_ungrouped_dims` --- tests/unit/sparse/test_diagonal_sparse_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 5f992f0d2..b21001a0d 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -169,6 +169,7 @@ def test_encode_by_order( ["v_to_ps", "expected_groupings"], [ ([[0, 1, 2], [2, 0, 1], [2]], [[0, 1], [2]]), + ([[0, 1, 0, 1]], [[0, 1]]), ], ) def test_get_groupings(v_to_ps: list[list[int]], expected_groupings: list[list[int]]): From 5a2a7a0a359c1dd94e8e03a8d987412b034b7f4e Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 29 Oct 2025 13:58:01 +0100 Subject: [PATCH 115/182] Add `strides_from_p_dims_and_p_shape` --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index c847c62b2..d61bc41bc 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -84,11 +84,8 @@ def to_dense( # and so, if we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2] # which isa unique decomposition, then this corresponds to indexing dimensions v_to_ps[i] at # indices [j_0, j_1, j_2] - s = self.physical.shape - strides = [ - list(accumulate([1] + [s[dim] for dim in dims[:0:-1]], operator.mul))[::-1] - for dims in self.v_to_ps - ] + s = list(self.physical.shape) + strides = [strides_from_p_dims_and_p_shape(dims, s) for dims in self.v_to_ps] # TODO: I think it's ok to create index tensors on CPU when tensor to index is on cuda. Idk # what's faster @@ -181,6 +178,12 @@ def decorator(func): return decorator +def strides_from_p_dims_and_p_shape(p_dims: list[int], physical_shape: list[int]) -> list[int]: + return list(accumulate([1] + [physical_shape[dim] for dim in p_dims[:0:-1]], operator.mul))[ + ::-1 + ] + + def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]]]: """ A physical dimension is mapped to a list of couples of the form @@ -650,6 +653,18 @@ def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSpar def einsum( *args: tuple[DiagonalSparseTensor, list[int]], output: list[int] ) -> DiagonalSparseTensor: + + # First part of the algorithm, determine how to cluster physical indices as well as the common + # p_shapes corresponding to matching v_dims. Second part translates to physical einsum. + + # new plan for first part: + # get a map from einsum index to (tensor_idx, v_dims) + # an index in the physical einsum is uniquely characterized by a virtual einsum index and a + # stride corresponding to the physical stride in the virtual one (note that as the virtual shape + # for two virtual index that match should match, then we want to match the strides and reshape + # accordingly). + # We want to cluster such indices whenever several appear in the same p_to_vs + # TODO: Handle ellipsis # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. 
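To make the stride convention concrete, a small sketch of the helper introduced above (example shapes only):

    # For a virtual dim backed by physical dims [0, 1, 2] of a physical tensor of shape
    # [2, 3, 4], the virtual index decomposes as j = j_0 * 12 + j_1 * 4 + j_2 * 1.
    assert strides_from_p_dims_and_p_shape([0, 1, 2], [2, 3, 4]) == [12, 4, 1]
    # The order of p_dims matters: over shape [2, 3], listing dim 1 before dim 0 gives
    # stride 2 to dim 1 and stride 1 to dim 0.
    assert strides_from_p_dims_and_p_shape([1, 0], [2, 3]) == [2, 1]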
From 4b8364494ab6084e31e9f85d9894597c2af10a7c Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 29 Oct 2025 14:05:56 +0100 Subject: [PATCH 116/182] Add `merge_strides` --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index d61bc41bc..82cd402b4 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -184,6 +184,10 @@ def strides_from_p_dims_and_p_shape(p_dims: list[int], physical_shape: list[int] ] +def merge_strides(strides: list[list[int]]) -> list[int]: + return sorted({s for stride in strides for s in stride}) + + def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]]]: """ A physical dimension is mapped to a list of couples of the form From 7e8439addceaf5b408e354bc356d97aa7351453f Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 29 Oct 2025 14:25:30 +0100 Subject: [PATCH 117/182] Add `stride_to_shape` --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 82cd402b4..f5985c1cd 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -188,6 +188,11 @@ def merge_strides(strides: list[list[int]]) -> list[int]: return sorted({s for stride in strides for s in stride}) +def stride_to_shape(numel: int, stride: list[int]) -> list[int]: + augmented_stride = [numel] + stride + return [a // b for a, b in zip(augmented_stride[:-1], augmented_stride[1:])] + + def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]]]: """ A physical dimension is mapped to a list of couples of the form From c7047e18393d6f101b27eb6cb17d51f04f570ae8 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 29 Oct 2025 14:49:06 +0100 Subject: [PATCH 118/182] Add `to_target_physical_strides` --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index f5985c1cd..06b82b16b 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -286,6 +286,32 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: return make_dst(t, [[i] for i in range(t.ndim)]) +def to_target_physical_strides( + physical: Tensor, v_to_ps: list[list[int]], strides: list[list[int]] +) -> tuple[Tensor, list[list[int]]]: + current_strides = [ + strides_from_p_dims_and_p_shape(p_dims, list(physical.shape)) for p_dims in v_to_ps + ] + target_stride = merge_strides(strides) + + numel = physical.numel() + target_shape = stride_to_shape(numel, target_stride) + new_physical = physical.reshape(target_shape) + + stride_to_p_dim = {s: i for i, s in enumerate(target_stride)} + stride_to_p_dim[0] = len(target_shape) + + new_v_to_ps = list[list[int]]() + for stride in current_strides: + extended_stride = stride + [0] + new_p_dims = list[int]() + for s_curr, s_next in zip(extended_stride[:-1], extended_stride[1:]): + new_p_dims += range(stride_to_p_dim[s_curr], stride_to_p_dim[s_next]) + new_v_to_ps.append(new_p_dims) + + return new_physical, new_v_to_ps + + def fix_dim_encoding(physical: Tensor, v_to_ps: list[list[int]]) -> tuple[Tensor, list[list[int]]]: v_to_ps, destination 
= encode_v_to_ps(v_to_ps) source = list(range(physical.ndim)) From c6e3fd9b7677422ce0265f2a29091897214b0503 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 29 Oct 2025 15:23:31 +0100 Subject: [PATCH 119/182] Add new_implementation idea in einsum. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 06b82b16b..a538d17bc 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -692,8 +692,18 @@ def einsum( # First part of the algorithm, determine how to cluster physical indices as well as the common # p_shapes corresponding to matching v_dims. Second part translates to physical einsum. - # new plan for first part: # get a map from einsum index to (tensor_idx, v_dims) + # get a map from einsum index to merge of strides corresponding to v_dims with that index + # use to_target_physical_strides on each physical and v_to_ps + # cluster pairs of (einsum_index, new_stride) using new_v_to_ps and possibly its corresponding + # p_to_vs + # get unique indices + # map output indices (there can be splits) + # call physical einsum + # build resulting dst + + # OVER + # an index in the physical einsum is uniquely characterized by a virtual einsum index and a # stride corresponding to the physical stride in the virtual one (note that as the virtual shape # for two virtual index that match should match, then we want to match the strides and reshape From d2e53a315315cd72cbc887c15e9503fb0ffe4d10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 30 Oct 2025 08:04:50 +0100 Subject: [PATCH 120/182] Always use randn_ in test --- tests/unit/sparse/test_diagonal_sparse_tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 57f0f36a5..4f31f6eb1 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -45,8 +45,8 @@ def test_einsum( b_indices: list[int], output_indices: list[int], ): - a = DiagonalSparseTensor(torch.randn(a_pshape), a_v_to_ps) - b = DiagonalSparseTensor(torch.randn(b_pshape), b_v_to_ps) + a = DiagonalSparseTensor(randn_(a_pshape), a_v_to_ps) + b = DiagonalSparseTensor(randn_(b_pshape), b_v_to_ps) res = einsum((a, a_indices), (b, b_indices), output=output_indices) @@ -212,7 +212,7 @@ def test_fix_ungrouped_dims( expected_physical_shape: list[int], expected_v_to_ps: list[list[int]], ): - physical = torch.randn(physical_shape) + physical = randn_(physical_shape) fixed_physical, fixed_v_to_ps = fix_ungrouped_dims(physical, v_to_ps) assert list(fixed_physical.shape) == expected_physical_shape @@ -240,7 +240,7 @@ def test_unsquash_pdim( expected_physical_shape: list[int], expected_new_encoding: list[list[int]], ): - physical = torch.randn(physical_shape) + physical = randn_(physical_shape) new_physical, new_encoding = unsquash_pdim(physical, pdim, new_pdim_shape) assert list(new_physical.shape) == expected_physical_shape From 16e7c1c7422c0cd1597e1b33c83a41f07ba1434f Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 30 Oct 2025 09:42:38 +0100 Subject: [PATCH 121/182] Fix order of sorting in merg_strides --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index a538d17bc..2bb2d40cf 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -185,7 +185,7 @@ def strides_from_p_dims_and_p_shape(p_dims: list[int], physical_shape: list[int] def merge_strides(strides: list[list[int]]) -> list[int]: - return sorted({s for stride in strides for s in stride}) + return sorted({s for stride in strides for s in stride}, reverse=True) def stride_to_shape(numel: int, stride: list[int]) -> list[int]: From 16e6165e9daad12fbc8209a3acd5436e0d427fc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 30 Oct 2025 13:12:41 +0100 Subject: [PATCH 122/182] Add strides_v2 --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 2bb2d40cf..b551a920c 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -184,6 +184,34 @@ def strides_from_p_dims_and_p_shape(p_dims: list[int], physical_shape: list[int] ] +def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]: + """ + From a list of physical dimensions corresponding to a virtual dimension, and from the physical + shape, get the stride indicating how moving on each physical dimension makes you move on the + virtual dimension. + + Example: + Imagine a vector of size 3, and of value [1, 2, 3]. + Imagine a DST t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps. + t.to_dense() is [1, 0, 0, 0, 2, 0, 0, 0, 3] (it's the flattening of the diagonal matrix + [[1, 0, 0], [0, 2, 0], [0, 0, 3]]). + When you move by 1 on physical dimension 0, you move by 4 on virtual dimension 0, i.e. + strides_v2([0, 0], [3]) = 4 + In the 2D view, you'd move by 1 row (3 indices) and 1 column (1 index). + + Example: + strides_v2([0, 0, 1], [3,4]) # [16, 1] + Moving by 1 on physical dimension 0 makes you move by 16 on the virtual dimension. Moving by + 1 on physical dimension 1 makes you move by 1 on the virtual dimension. 
+ """ + + strides_v1 = strides_from_p_dims_and_p_shape(p_dims, physical_shape) + result = [0 for _ in range(len(physical_shape))] + for i, d in enumerate(p_dims): + result[d] += strides_v1[i] + return result + + def merge_strides(strides: list[list[int]]) -> list[int]: return sorted({s for stride in strides for s in stride}, reverse=True) From 59cf10b326b903b9438d00b908ab4e77f022bad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 30 Oct 2025 13:13:18 +0100 Subject: [PATCH 123/182] Add more failing parametrizations to test_get_groupings and test_fix_ungrouped_dims, remove print --- tests/unit/sparse/test_diagonal_sparse_tensor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 4f31f6eb1..7ce66b12f 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -190,11 +190,13 @@ def test_encode_by_order( [ ([[0, 1, 2], [2, 0, 1], [2]], [[0, 1], [2]]), ([[0, 1, 0, 1]], [[0, 1]]), + ([[0, 1, 0, 1, 2]], [[0, 1], [2]]), + ([[0, 0]], [[0, 0]]), + ([[0, 1], [1, 2]], [[0], [1], [2]]), ], ) def test_get_groupings(v_to_ps: list[list[int]], expected_groupings: list[list[int]]): groupings = get_groupings(v_to_ps) - print(groupings) assert groupings == expected_groupings @@ -204,6 +206,7 @@ def test_get_groupings(v_to_ps: list[list[int]], expected_groupings: list[list[i [ ([3, 4, 5], [[0, 1, 2], [2, 0, 1], [2]], [12, 5], [[0, 1], [1, 0], [1]]), ([32, 20, 8], [[0], [1, 0], [2]], [32, 20, 8], [[0], [1, 0], [2]]), + ([3, 3, 4], [[0, 1], [1, 2]], [3, 3, 4], [[0, 1], [1, 2]]), ], ) def test_fix_ungrouped_dims( From e0dc1a705b728bf8f840a39c08e7da1633fa68ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 30 Oct 2025 13:13:54 +0100 Subject: [PATCH 124/182] Add test_concatenate --- .../sparse/test_diagonal_sparse_tensor.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index 7ce66b12f..e0ce0d22c 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -248,3 +248,22 @@ def test_unsquash_pdim( assert list(new_physical.shape) == expected_physical_shape assert new_encoding == expected_new_encoding + + +@mark.parametrize( + ["dst_args", "dim"], + [ + ([([3, 4], [[0], [0, 1]]), ([3, 3, 4], [[0, 1], [1, 2]])], 0), + ([([3, 12], [[0, 1], [0]]), ([9, 4], [[0, 1], [0]])], 1), + ], +) +def test_concatenate( + dst_args: list[tuple[list[int], list[list[int]]]], + dim: int, +): + tensors = [DiagonalSparseTensor(randn_(pshape), v_to_ps) for pshape, v_to_ps in dst_args] + res = aten.cat.default(tensors, dim) + expected = aten.cat.default([t.to_dense() for t in tensors], dim) + + assert isinstance(res, DiagonalSparseTensor) + assert torch.all(torch.eq(res.to_dense(), expected)) From 3dffd1e25025dcfd0c09bf1743de2b3befbb5fbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 30 Oct 2025 15:00:40 +0100 Subject: [PATCH 125/182] Add strides_to_pdims --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index b551a920c..bd411231a 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ 
b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -212,6 +212,46 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]: return result +def argmax(iterable): + return max(enumerate(iterable), key=lambda x: x[1])[0] + + +def strides_to_pdims(strides: list[int], physical_shape: list[int]) -> list[int]: + """ + Given a list of strides, find and return the used physical dimensions. + + This algorithm runs in O(n * m) with n the number of physical dimensions (i.e. + len(physical_shape) and len(strides)), and with m the number of pdims in the result. + + I'm pretty sure it could be implemented in O((n+m)log(n)) by using a sorted linked list for the + remaining_strides, and keeping it sorted each time we update it. Argmax would then always be 0, + removing the need to go through the whole list at every iteration. + """ + + # e.g. strides = [22111, 201000], physical_shape = [10, 2] + + pdims = [] + remaining_strides = [s for s in strides] + remaining_numel = ( + sum(remaining_strides[i] * (physical_shape[i] - 1) for i in range(len(physical_shape))) + 1 + ) + # e.g. 9 * 22111 + 1 * 201000 + 1 = 400000 + + while sum(remaining_strides) > 0: + current_pdim = argmax(remaining_strides) + # e.g. 1 + + pdims.append(current_pdim) + + remaining_numel = remaining_numel // physical_shape[current_pdim] + # e.g. 400000 / 2 = 200000 + + remaining_strides[current_pdim] -= remaining_numel + # e.g. [22111, 1000] + + return pdims + + def merge_strides(strides: list[list[int]]) -> list[int]: return sorted({s for stride in strides for s in stride}, reverse=True) From 48387aba9e2e99a80d22765412dc257f89c421db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 07:15:15 +0100 Subject: [PATCH 126/182] Add (passing) test_to_dense2 to test to_dense when the tensor has a virtual dimension that uses a physical dimension multiple times. --- tests/unit/sparse/test_diagonal_sparse_tensor.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index e0ce0d22c..b896fcd2a 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -2,7 +2,7 @@ from pytest import mark from torch.ops import aten # type: ignore from torch.testing import assert_close -from utils.tensors import randn_, zeros_ +from utils.tensors import randn_, tensor_, zeros_ from torchjd.sparse._diagonal_sparse_tensor import ( _IN_PLACE_POINTWISE_FUNCTIONS, @@ -28,6 +28,14 @@ def test_to_dense(): assert c[i, j, j, i] == a[i, j] +def test_to_dense2(): + a = tensor_([1.0, 2.0, 3.0]) + b = DiagonalSparseTensor(a, [[0, 0]]) + c = b.to_dense() + expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]) + assert torch.all(torch.eq(c, expected)) + + @mark.parametrize( ["a_pshape", "a_v_to_ps", "b_pshape", "b_v_to_ps", "a_indices", "b_indices", "output_indices"], [ From f87ecbb8370eb884f3697a2cc805ee5aef78bdb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 07:18:23 +0100 Subject: [PATCH 127/182] Use new strides in to_dense * The result is the same as before * Before that we only iterated on the pdims used by each virtual dim, and summed them if a pdim was present multiple times. * Now the new stride is already the sum of the old strides when a pdim is present multiple times in a vdim. We iterate over all dimensions, because for dimensions not present in the vdim, the stride is simply 0. 
* There's probably a more efficient implementation --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 26 +++++-------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index bd411231a..5e6ba598a 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -76,36 +76,22 @@ def to_dense( if self.physical.ndim == 0: return self.physical - # This is a list of strides whose shape matches that of v_to_ps except that each element - # is the stride factor of the index to get the right element for the corresponding virtual - # dimension. Stride is the jump necessary to go from one element to the next one in the - # specified dimension. For instance if the i'th element of v_to_ps is [0, 1, 2], then the - # i'th element of _strides is [physical.shape[1] * physical.shape[2], physical.shape[2], 1] - # and so, if we index dimension i with j=j_0 * stride[0] + j_1 * stride[1] + j_2 * stride[2] - # which isa unique decomposition, then this corresponds to indexing dimensions v_to_ps[i] at - # indices [j_0, j_1, j_2] - s = list(self.physical.shape) - strides = [strides_from_p_dims_and_p_shape(dims, s) for dims in self.v_to_ps] + strides = [strides_v2(p_dims, list(self.physical.shape)) for p_dims in self.v_to_ps] # TODO: I think it's ok to create index tensors on CPU when tensor to index is on cuda. Idk # what's faster p_index_ranges = [torch.arange(s) for s in self.physical.shape] p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") v_indices_grid = list[Tensor]() + all_pdims = list(range(self.physical.ndim)) for stride, dims in zip(strides, self.v_to_ps): stride_ = torch.tensor(stride, dtype=torch.int) - if len(dims) > 0: - v_indices_grid.append( - torch.sum( - torch.stack([p_indices_grid[d] for d in dims], dim=-1) * stride_, dim=-1 - ) + v_indices_grid.append( + torch.sum( + torch.stack([p_indices_grid[d] for d in all_pdims], dim=-1) * stride_, dim=-1 ) - else: - v_indices_grid.append(torch.tensor(0, dtype=torch.int)) - # This is supposed to be a vector of shape d_1 * d_2 ... - # whose elements are the coordinates 1 in p_indices_grad[d_1] times stride 1 - # plus coordinates 2 in p_indices_grad[d_2] times stride 2, etc... + ) res = torch.zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) res[tuple(v_indices_grid)] = self.physical From b883f2007179d5fa45a90fb2311b805efa34d47d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 07:27:11 +0100 Subject: [PATCH 128/182] Simplify to_dense * Now that we iterate over all_pdims instead of the pdims of the current virtual dimension, the result of torch.stack([p_indices_grid[d] for d in all_pdims], dim=-1) is always the same, and is simply equal to torch.stack(p_indices_grid, dim=-1). So we directly stack the p_indices_grid when creating it, and use the already stacked p_indices_grid in the for-loop. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 5e6ba598a..69c1ed16b 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -81,17 +81,11 @@ def to_dense( # TODO: I think it's ok to create index tensors on CPU when tensor to index is on cuda. 
Idk # what's faster p_index_ranges = [torch.arange(s) for s in self.physical.shape] - p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") + p_indices_grid = torch.stack(torch.meshgrid(*p_index_ranges, indexing="ij"), dim=-1) v_indices_grid = list[Tensor]() - all_pdims = list(range(self.physical.ndim)) for stride, dims in zip(strides, self.v_to_ps): stride_ = torch.tensor(stride, dtype=torch.int) - - v_indices_grid.append( - torch.sum( - torch.stack([p_indices_grid[d] for d in all_pdims], dim=-1) * stride_, dim=-1 - ) - ) + v_indices_grid.append(torch.sum(p_indices_grid * stride_, dim=-1)) res = torch.zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) res[tuple(v_indices_grid)] = self.physical From 7044e378598d15ddccf0d9bf250f06fe501c4280 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 07:30:37 +0100 Subject: [PATCH 129/182] Remove unused variable dims in the loop of to_dense --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 69c1ed16b..46c66c50c 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -83,7 +83,7 @@ def to_dense( p_index_ranges = [torch.arange(s) for s in self.physical.shape] p_indices_grid = torch.stack(torch.meshgrid(*p_index_ranges, indexing="ij"), dim=-1) v_indices_grid = list[Tensor]() - for stride, dims in zip(strides, self.v_to_ps): + for stride in strides: stride_ = torch.tensor(stride, dtype=torch.int) v_indices_grid.append(torch.sum(p_indices_grid * stride_, dim=-1)) From 384d550386038670f9116c6c54cee3869ba30c0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 07:33:13 +0100 Subject: [PATCH 130/182] Pre-compute strides * In the long term I'll try to rely mostly or even only on them instead of v_to_ps, so it makes sense to pre-compute them. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 46c66c50c..7e7b998b2 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -66,6 +66,7 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): self.physical = physical self.v_to_ps = v_to_ps + self.strides = [strides_v2(pdims, list(self.physical.shape)) for pdims in self.v_to_ps] def to_dense( self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None @@ -76,14 +77,12 @@ def to_dense( if self.physical.ndim == 0: return self.physical - strides = [strides_v2(p_dims, list(self.physical.shape)) for p_dims in self.v_to_ps] - # TODO: I think it's ok to create index tensors on CPU when tensor to index is on cuda. 
Idk # what's faster p_index_ranges = [torch.arange(s) for s in self.physical.shape] p_indices_grid = torch.stack(torch.meshgrid(*p_index_ranges, indexing="ij"), dim=-1) v_indices_grid = list[Tensor]() - for stride in strides: + for stride in self.strides: stride_ = torch.tensor(stride, dtype=torch.int) v_indices_grid.append(torch.sum(p_indices_grid * stride_, dim=-1)) From 18044dda9d7b26ec668ed83fe31e273bf48c75a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 07:35:49 +0100 Subject: [PATCH 131/182] Replace for-loop with for comprehension to create v_indices_grid, and make it directly as a tuple --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 7e7b998b2..d4b8963f4 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -81,13 +81,10 @@ def to_dense( # what's faster p_index_ranges = [torch.arange(s) for s in self.physical.shape] p_indices_grid = torch.stack(torch.meshgrid(*p_index_ranges, indexing="ij"), dim=-1) - v_indices_grid = list[Tensor]() - for stride in self.strides: - stride_ = torch.tensor(stride, dtype=torch.int) - v_indices_grid.append(torch.sum(p_indices_grid * stride_, dim=-1)) - + strides = [torch.tensor(stride) for stride in self.strides] + v_indices_grid = tuple(torch.sum(p_indices_grid * stride, dim=-1) for stride in strides) res = torch.zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) - res[tuple(v_indices_grid)] = self.physical + res[v_indices_grid] = self.physical return res @classmethod From b0a0e7a5564b5591d4e5b40945ef34a417451a92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 07:54:08 +0100 Subject: [PATCH 132/182] Remove for-loop in computation of v_indices_grid in to_dense --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index d4b8963f4..b004a012c 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -80,11 +80,11 @@ def to_dense( # TODO: I think it's ok to create index tensors on CPU when tensor to index is on cuda. 
Idk # what's faster p_index_ranges = [torch.arange(s) for s in self.physical.shape] - p_indices_grid = torch.stack(torch.meshgrid(*p_index_ranges, indexing="ij"), dim=-1) - strides = [torch.tensor(stride) for stride in self.strides] - v_indices_grid = tuple(torch.sum(p_indices_grid * stride, dim=-1) for stride in strides) + p_indices_grid = torch.stack(torch.meshgrid(*p_index_ranges, indexing="ij")) + strides = torch.stack([torch.tensor(stride) for stride in self.strides]) + v_indices_grid = torch.tensordot(strides, p_indices_grid, dims=1) res = torch.zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) - res[v_indices_grid] = self.physical + res[tuple(v_indices_grid)] = self.physical return res @classmethod From 456adf141281ddd72d0fa9f07245c9da870b6e2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 09:04:25 +0100 Subject: [PATCH 133/182] Add internal strides to debug_info --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index b004a012c..679870afd 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -137,6 +137,7 @@ def debug_info(self) -> str: f"shape: {self.shape}\n" f"stride(): {self.stride()}\n" f"v_to_ps: {self.v_to_ps}\n" + f"strides: {self.strides}\n" f"physical.shape: {self.physical.shape}\n" f"physical.stride(): {self.physical.stride()}" ) From dc37d29c4660782f5ca772b422a6e053be8d6466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 09:10:38 +0100 Subject: [PATCH 134/182] Simplify creation of strides in to_dense --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 679870afd..3e5b6ea82 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -81,7 +81,7 @@ def to_dense( # what's faster p_index_ranges = [torch.arange(s) for s in self.physical.shape] p_indices_grid = torch.stack(torch.meshgrid(*p_index_ranges, indexing="ij")) - strides = torch.stack([torch.tensor(stride) for stride in self.strides]) + strides = torch.tensor(self.strides) v_indices_grid = torch.tensordot(strides, p_indices_grid, dims=1) res = torch.zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) res[tuple(v_indices_grid)] = self.physical From a9e298ac1e2d235aa4f5357971296c0825fe6648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 09:11:44 +0100 Subject: [PATCH 135/182] Remove torch. 
prefix when possible * This makes lines shorter --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 3e5b6ea82..13ff909cd 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -4,12 +4,12 @@ from math import prod import torch -from torch import Tensor +from torch import Tensor, arange, meshgrid, stack, tensor, tensordot, zeros from torch.ops import aten # type: ignore from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten -class DiagonalSparseTensor(torch.Tensor): +class DiagonalSparseTensor(Tensor): _HANDLED_FUNCTIONS = dict() @staticmethod @@ -79,11 +79,11 @@ def to_dense( # TODO: I think it's ok to create index tensors on CPU when tensor to index is on cuda. Idk # what's faster - p_index_ranges = [torch.arange(s) for s in self.physical.shape] - p_indices_grid = torch.stack(torch.meshgrid(*p_index_ranges, indexing="ij")) - strides = torch.tensor(self.strides) - v_indices_grid = torch.tensordot(strides, p_indices_grid, dims=1) - res = torch.zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) + p_index_ranges = [arange(s) for s in self.physical.shape] + p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij")) + strides = tensor(self.strides) + v_indices_grid = tensordot(strides, p_indices_grid, dims=1) + res = zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) res[tuple(v_indices_grid)] = self.physical return res @@ -691,12 +691,12 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> DiagonalSp assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) if isinstance(t1, int) or isinstance(t1, float): - t1_ = torch.tensor(t1, device=t2.device) + t1_ = tensor(t1, device=t2.device) else: t1_ = t1 if isinstance(t2, int) or isinstance(t2, float): - t2_ = torch.tensor(t2, device=t1.device) + t2_ = tensor(t2, device=t1.device) else: t2_ = t2 From ac2a2c5934a59fb1f1e4fd47c3aca70e288876cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 09:14:42 +0100 Subject: [PATCH 136/182] Create strides tensor in constructor --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 13ff909cd..ec4138c82 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -66,7 +66,11 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): self.physical = physical self.v_to_ps = v_to_ps - self.strides = [strides_v2(pdims, list(self.physical.shape)) for pdims in self.v_to_ps] + pshape = list(self.physical.shape) + + # TODO: not sure if strides should be always on cpu (e.g. if it's only used for indexing) + # or if we should put it on physical.device. 
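+        # For instance, with a physical of shape [3, 4] and v_to_ps [[0, 1], [0]], this gives
+        # tensor([[4, 1], [1, 0]]): one row per virtual dimension, one column per physical
+        # dimension.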
+ self.strides = tensor([strides_v2(pdims, pshape) for pdims in self.v_to_ps]) def to_dense( self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None @@ -81,8 +85,7 @@ def to_dense( # what's faster p_index_ranges = [arange(s) for s in self.physical.shape] p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij")) - strides = tensor(self.strides) - v_indices_grid = tensordot(strides, p_indices_grid, dims=1) + v_indices_grid = tensordot(self.strides, p_indices_grid, dims=1) res = zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) res[tuple(v_indices_grid)] = self.physical return res From cfba7e01e9c7783222b69da65a47c37a3397d8ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 09:20:26 +0100 Subject: [PATCH 137/182] Remove comments about the device of the indices tensors. We actually cannot have strides on cuda because addmm_cuda (required for tensordot) does not support Long tensors (but the cpu version does). --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index ec4138c82..b3d4fa07a 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -67,9 +67,6 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): self.physical = physical self.v_to_ps = v_to_ps pshape = list(self.physical.shape) - - # TODO: not sure if strides should be always on cpu (e.g. if it's only used for indexing) - # or if we should put it on physical.device. self.strides = tensor([strides_v2(pdims, pshape) for pdims in self.v_to_ps]) def to_dense( @@ -81,10 +78,10 @@ def to_dense( if self.physical.ndim == 0: return self.physical - # TODO: I think it's ok to create index tensors on CPU when tensor to index is on cuda. 
Idk - # what's faster p_index_ranges = [arange(s) for s in self.physical.shape] p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij")) + + # addmm_cuda not implemented for Long tensors => gotta have these tensors on cpu v_indices_grid = tensordot(self.strides, p_indices_grid, dims=1) res = zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) res[tuple(v_indices_grid)] = self.physical From 2c94488793fe8869a6b4e40d9ed1d43e7995b359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 09:21:10 +0100 Subject: [PATCH 138/182] Replace self.physical.device and self.physical.dtype by self.device and self.dtype --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index b3d4fa07a..e015cea27 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -83,7 +83,7 @@ def to_dense( # addmm_cuda not implemented for Long tensors => gotta have these tensors on cpu v_indices_grid = tensordot(self.strides, p_indices_grid, dims=1) - res = zeros(self.shape, device=self.physical.device, dtype=self.physical.dtype) + res = zeros(self.shape, device=self.device, dtype=self.dtype) res[tuple(v_indices_grid)] = self.physical return res From 9c1ad5bd5ec534031351bcf086bcb649cb73b9a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 12:58:24 +0100 Subject: [PATCH 139/182] Move unwrap_to_dense out of DiagonalSparseTensor --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index e015cea27..187095a6e 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -94,13 +94,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): if func in cls._HANDLED_FUNCTIONS: return cls._HANDLED_FUNCTIONS[func](*args, **kwargs) - # --- Fallback: Fold to Dense Tensor --- - def unwrap_to_dense(t: Tensor): - if isinstance(t, cls): - return t.to_dense() - else: - return t - def tensor_to_str(tensor: Tensor) -> str: result = f"{tensor.__class__.__name__} - shape: {tensor.shape}" if isinstance(tensor, DiagonalSparseTensor): @@ -331,6 +324,13 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: return make_dst(t, [[i] for i in range(t.ndim)]) +def unwrap_to_dense(t: Tensor): + if isinstance(t, DiagonalSparseTensor): + return t.to_dense() + else: + return t + + def to_target_physical_strides( physical: Tensor, v_to_ps: list[list[int]], strides: list[list[int]] ) -> tuple[Tensor, list[list[int]]]: From 72527570918cfeddddc3d292edcee355676ea539 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 13:06:33 +0100 Subject: [PATCH 140/182] Extract print_fallback --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 51 ++++++++++--------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 187095a6e..991458681 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -94,30 +94,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): if func in cls._HANDLED_FUNCTIONS: return 
cls._HANDLED_FUNCTIONS[func](*args, **kwargs) - def tensor_to_str(tensor: Tensor) -> str: - result = f"{tensor.__class__.__name__} - shape: {tensor.shape}" - if isinstance(tensor, DiagonalSparseTensor): - result += f" - pshape: {tensor.physical.shape} - v_to_ps: {tensor.v_to_ps}" - - return result - - print(f"Falling back to dense for {func.__name__}") - if len(args) > 0: - print("* args:") - for arg in args: - if isinstance(arg, Tensor): - print(f" > {tensor_to_str(arg)}") - elif isinstance(arg, list) and len(arg) > 0 and isinstance(arg[0], Tensor): - list_content = "\n ".join([tensor_to_str(t) for t in arg]) - print(f" > [{list_content}]") - else: - print(f" > {arg}") - if len(kwargs) > 0: - print("* kwargs:") - for k, v in kwargs.items(): - print(f" > {k}: {v}") - print() - + print_fallback(func, args, kwargs) unwrapped_args = tree_map(unwrap_to_dense, args) unwrapped_kwargs = tree_map(unwrap_to_dense, kwargs) return func(*unwrapped_args, **unwrapped_kwargs) @@ -148,6 +125,32 @@ def decorator(func): return decorator +def print_fallback(func, args, kwargs) -> None: + def tensor_to_str(t: Tensor) -> str: + result = f"{t.__class__.__name__} - shape: {t.shape}" + if isinstance(t, DiagonalSparseTensor): + result += f" - pshape: {t.physical.shape} - v_to_ps: {t.v_to_ps}" + + return result + + print(f"Falling back to dense for {func.__name__}") + if len(args) > 0: + print("* args:") + for arg in args: + if isinstance(arg, Tensor): + print(f" > {tensor_to_str(arg)}") + elif isinstance(arg, list) and len(arg) > 0 and isinstance(arg[0], Tensor): + list_content = "\n ".join([tensor_to_str(t) for t in arg]) + print(f" > [{list_content}]") + else: + print(f" > {arg}") + if len(kwargs) > 0: + print("* kwargs:") + for k, v in kwargs.items(): + print(f" > {k}: {v}") + print() + + def strides_from_p_dims_and_p_shape(p_dims: list[int], physical_shape: list[int]) -> list[int]: return list(accumulate([1] + [physical_shape[dim] for dim in p_dims[:0:-1]], operator.mul))[ ::-1 From 7b4c784122a464bc590c336c4b02f8c7c0d65871 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 13:27:47 +0100 Subject: [PATCH 141/182] Add placeholder cat_default --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 991458681..01c81a35e 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -472,6 +472,17 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor return DiagonalSparseTensor(t.physical, new_v_to_ps) +@DiagonalSparseTensor.implements(aten.cat.default) +def cat_default(tensors: list[Tensor], dim: int) -> Tensor: + if any(not isinstance(t, DiagonalSparseTensor) for t in tensors): + print_fallback(aten.cat.default, (tensors, dim), {}) + return aten.cat.default([unwrap_to_dense(t) for t in tensors]) + + else: + # TODO: efficient implementation when all tensors are sparse + return aten.cat.default([unwrap_to_dense(t) for t in tensors]) + + def unsquash_pdim( physical: Tensor, pdim: int, new_pdim_shape: list[int] ) -> tuple[Tensor, list[list[int]]]: From 767281c0ab2d363b68793a96ecf02fe80215b356 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 13:28:14 +0100 Subject: [PATCH 142/182] Add to_most_efficient_tensor, use it in view and einsum --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 23 +++++++++++++------ 1 
file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 01c81a35e..c132a4cf2 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -327,6 +327,17 @@ def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: return make_dst(t, [[i] for i in range(t.ndim)]) +def to_most_efficient_tensor(physical: Tensor, v_to_ps: list[list[int]]) -> Tensor: + physical, v_to_ps = fix_dim_encoding(physical, v_to_ps) + physical, v_to_ps = fix_dim_of_size_1(physical, v_to_ps) + physical, v_to_ps = fix_ungrouped_dims(physical, v_to_ps) + + if sum([len(pdims) for pdims in v_to_ps]) == physical.ndim: + return torch.movedim(physical, list(range(physical.ndim)), [pdims[0] for pdims in v_to_ps]) + else: + return DiagonalSparseTensor(physical, v_to_ps) + + def unwrap_to_dense(t: Tensor): if isinstance(t, DiagonalSparseTensor): return t.to_dense() @@ -514,7 +525,7 @@ def infer_shape(shape: list[int], numel: int) -> list[int]: @DiagonalSparseTensor.implements(aten.view.default) -def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTensor: +def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: assert isinstance(t, DiagonalSparseTensor) shape = infer_shape(shape, t.numel()) @@ -574,11 +585,11 @@ def view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTen # The above code does not handle physical dimension squashing, so the physical is not # necessarily maximally squashed at this point, so we need the safe constructor. - return make_dst(new_physical, new_v_to_ps) + return to_most_efficient_tensor(new_physical, new_v_to_ps) @DiagonalSparseTensor.implements(aten._unsafe_view.default) -def _unsafe_view_default(t: DiagonalSparseTensor, shape: list[int]) -> DiagonalSparseTensor: +def _unsafe_view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: return view_default( t, shape ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp @@ -744,9 +755,7 @@ def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSpar return DiagonalSparseTensor(new_physical, new_v_to_ps) -def einsum( - *args: tuple[DiagonalSparseTensor, list[int]], output: list[int] -) -> DiagonalSparseTensor: +def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) -> Tensor: # First part of the algorithm, determine how to cluster physical indices as well as the common # p_shapes corresponding to matching v_dims. Second part translates to physical einsum. @@ -853,7 +862,7 @@ def unique_int(pair: tuple[int, int]) -> int: physical = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output) # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped. # Maybe there is a way to fix that though. 
- return make_dst(physical, v_to_ps) + return to_most_efficient_tensor(physical, v_to_ps) @DiagonalSparseTensor.implements(aten.bmm.default) From 6a3145fbda1a48025f74cae7e23457cb2aa5c14d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 13:35:16 +0100 Subject: [PATCH 143/182] Add permute_default --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index c132a4cf2..b215a5eeb 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -483,6 +483,14 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor return DiagonalSparseTensor(t.physical, new_v_to_ps) +@DiagonalSparseTensor.implements(aten.permute.default) +def permute_default(t: DiagonalSparseTensor, dims: list[int]) -> DiagonalSparseTensor: + new_v_to_ps = [t.v_to_ps[d] for d in dims] + + new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) + return DiagonalSparseTensor(new_physical, new_v_to_ps) + + @DiagonalSparseTensor.implements(aten.cat.default) def cat_default(tensors: list[Tensor], dim: int) -> Tensor: if any(not isinstance(t, DiagonalSparseTensor) for t in tensors): From 9f1860a39a7c163047a452fe54b9385a82940e6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 13:37:46 +0100 Subject: [PATCH 144/182] Fix some return type hints --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index b215a5eeb..1ed0dccd2 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -719,7 +719,7 @@ def slice_Tensor( @DiagonalSparseTensor.implements(aten.mul.Tensor) -def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> DiagonalSparseTensor: +def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: # Element-wise multiplication with broadcasting assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) @@ -874,7 +874,7 @@ def unique_int(pair: tuple[int, int]) -> int: @DiagonalSparseTensor.implements(aten.bmm.default) -def bmm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: +def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) assert ( mat1.ndim == 3 @@ -892,7 +892,7 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: @DiagonalSparseTensor.implements(aten.mm.default) -def mm_default(mat1: Tensor, mat2: Tensor) -> DiagonalSparseTensor: +def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] From 137f64e20dffa20d941ed998f81f8eccfe927e07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 13:50:41 +0100 Subject: [PATCH 145/182] Add div_Tensor, factorize prepare_for_elementwise_op --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 1ed0dccd2..0c226169d 100644 --- 
a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -718,9 +718,14 @@ def slice_Tensor( return DiagonalSparseTensor(new_physical, t.v_to_ps) -@DiagonalSparseTensor.implements(aten.mul.Tensor) -def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: - # Element-wise multiplication with broadcasting +def prepare_for_elementwise_op( + t1: Tensor | int | float, t2: Tensor | int | float +) -> tuple[DiagonalSparseTensor, DiagonalSparseTensor]: + """ + Prepares two DSTs of the same shape from two args, one of those being a DST, and the other being + a DST, Tensor, int or float. + """ + assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) if isinstance(t1, int) or isinstance(t1, float): @@ -737,6 +742,21 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: t1_ = to_diagonal_sparse_tensor(t1_) t2_ = to_diagonal_sparse_tensor(t2_) + return t1_, t2_ + + +@DiagonalSparseTensor.implements(aten.mul.Tensor) +def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + # Element-wise multiplication with broadcasting + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@DiagonalSparseTensor.implements(aten.div.Tensor) +def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + t2_ = DiagonalSparseTensor(1.0 / t2_.physical, t2_.v_to_ps) all_dims = list(range(t1_.ndim)) return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) From dc4c1a54bb23341b834a5ca235192936eb3433ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 31 Oct 2025 13:59:38 +0100 Subject: [PATCH 146/182] Add add_Tensor for same v_to_ps --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 0c226169d..65fb29a70 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -771,6 +771,19 @@ def mul_Scalar(t: DiagonalSparseTensor, scalar) -> DiagonalSparseTensor: return DiagonalSparseTensor(new_physical, t.v_to_ps) +@DiagonalSparseTensor.implements(aten.add.Tensor) +def add_Tensor( + t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 +) -> DiagonalSparseTensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + + if t1_.v_to_ps == t2_.v_to_ps: + new_physical = t1_.physical + t2_.physical * alpha + return DiagonalSparseTensor(new_physical, t1_.v_to_ps) + else: + raise NotImplementedError() + + @DiagonalSparseTensor.implements(aten.transpose.int) def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSparseTensor: assert isinstance(t, DiagonalSparseTensor) From f057d7e959bb4176ebbb9f26fcd30a3e243e8179 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 1 Nov 2025 08:59:56 +0100 Subject: [PATCH 147/182] Add squeeze_dims --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 65fb29a70..55b538c34 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -483,6 +483,22 @@ def unsqueeze_default(t: 
DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor return DiagonalSparseTensor(t.physical, new_v_to_ps) +@DiagonalSparseTensor.implements(aten.squeeze.dims) +def squeeze_dims(t: DiagonalSparseTensor, dims: list[int] | int | None) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + if dims is None: + excluded = set(range(t.ndim)) + elif isinstance(dims, int): + excluded = {dims} + else: + excluded = set(dims) + + new_v_to_ps = [pdims for i, pdims in enumerate(t.v_to_ps) if i not in excluded] + + return to_most_efficient_tensor(t.physical, new_v_to_ps) + + @DiagonalSparseTensor.implements(aten.permute.default) def permute_default(t: DiagonalSparseTensor, dims: list[int]) -> DiagonalSparseTensor: new_v_to_ps = [t.v_to_ps[d] for d in dims] From dc356966b96d71e4a024fba548d5224409345add Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 1 Nov 2025 09:14:15 +0100 Subject: [PATCH 148/182] Handle cases where pdims is empty list in the creation of normal tensor in to_most_efficient_tensor * The idea is to add as many physical dimensions at the end of the physical tensor (unsqueeze(-1)) as needed, and the create the corresponding new_v_to_ps. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 55b538c34..9cdc8d85c 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -333,7 +333,22 @@ def to_most_efficient_tensor(physical: Tensor, v_to_ps: list[list[int]]) -> Tens physical, v_to_ps = fix_ungrouped_dims(physical, v_to_ps) if sum([len(pdims) for pdims in v_to_ps]) == physical.ndim: - return torch.movedim(physical, list(range(physical.ndim)), [pdims[0] for pdims in v_to_ps]) + next_physical_index = physical.ndim + new_v_to_ps = [] + # Add as many dimensions of size 1 as there are pdims equal to [] in v_to_ps. + # Create the corresponding new_v_to_ps. + # E.g. if v_to_ps is [[0], [], [1]], new_v_to_ps is [[0], [2], [1]]. 
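+        # In that example, the physical (of shape [s0, s1]) gets one trailing dimension of
+        # size 1, and the movedim below then returns a plain tensor of shape [s0, 1, s1].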
+ for vdim, pdims in enumerate(v_to_ps): + if len(pdims) == 0: + physical = physical.unsqueeze(-1) + new_v_to_ps.append([next_physical_index]) + next_physical_index += 1 + else: + new_v_to_ps.append(pdims) + + return torch.movedim( + physical, list(range(physical.ndim)), [pdims[0] for pdims in new_v_to_ps] + ) else: return DiagonalSparseTensor(physical, v_to_ps) From c97612f8d7f9bd2128c20ddbe03cd267dbf10f92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 1 Nov 2025 09:29:52 +0100 Subject: [PATCH 149/182] Add hardtanh_backward_default --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 9cdc8d85c..f2202dc52 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -713,6 +713,20 @@ def threshold_backward_default( return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) +@DiagonalSparseTensor.implements(aten.hardtanh_backward.default) +def hardtanh_backward_default( + grad_output: DiagonalSparseTensor, + self: Tensor, + min_val: Tensor | int | float, + max_val: Tensor | int | float, +) -> DiagonalSparseTensor: + if isinstance(self, DiagonalSparseTensor): + raise NotImplementedError() + + new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) + return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) + + @DiagonalSparseTensor.implements(aten.slice.Tensor) def slice_Tensor( t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 From 56b66ba9d624cd77cede465135fce0443e132e4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 1 Nov 2025 09:34:05 +0100 Subject: [PATCH 150/182] Add hardswish_backward_default * It seems like we could factorize code to handle all activation function backwards easily, and maybe even all functions that can be partialled into a pointwise function. 
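* A minimal sketch of what such a factorization could look like (the helper name and the exact
  set of registered overloads are illustrative, assuming this module's existing imports):

    def _register_pointwise_backward(op):
        # op must be pointwise in grad_output; the remaining arguments are plain tensors or
        # scalars, exactly as in the per-op overrides above.
        def handler(grad_output: DiagonalSparseTensor, *args):
            if any(isinstance(a, DiagonalSparseTensor) for a in args):
                raise NotImplementedError()
            new_physical = op(grad_output.physical, *args)
            return DiagonalSparseTensor(new_physical, grad_output.v_to_ps)

        DiagonalSparseTensor.implements(op)(handler)

    for op in [
        aten.threshold_backward.default,
        aten.hardtanh_backward.default,
        aten.hardswish_backward.default,
    ]:
        _register_pointwise_backward(op)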
--- src/torchjd/sparse/_diagonal_sparse_tensor.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index f2202dc52..7a5a0f742 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -727,6 +727,15 @@ def hardtanh_backward_default( return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) +@DiagonalSparseTensor.implements(aten.hardswish_backward.default) +def hardswish_backward_default(grad_output: DiagonalSparseTensor, self: Tensor): + if isinstance(self, DiagonalSparseTensor): + raise NotImplementedError() + + new_physical = aten.hardswish_backward.default(grad_output.physical, self) + return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) + + @DiagonalSparseTensor.implements(aten.slice.Tensor) def slice_Tensor( t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 From 10f250cadd96a8e05b9114d074431ba4600860a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 1 Nov 2025 14:01:51 +0100 Subject: [PATCH 151/182] Add more concatenate tests --- tests/unit/sparse/test_diagonal_sparse_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index b896fcd2a..f4642846f 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -261,6 +261,10 @@ def test_unsquash_pdim( @mark.parametrize( ["dst_args", "dim"], [ + ([([3], [[0], [0]]), ([3], [[0], [0]])], 1), + ([([3, 2], [[0], [0, 1]]), ([3], [[0], [0]])], 1), + ([([3], [[0], [0]]), ([3, 2], [[0], [0, 1]])], 1), + ([([3, 2], [[0], [0, 1]]), ([3, 2], [[0], [0, 1]])], 1), ([([3, 4], [[0], [0, 1]]), ([3, 3, 4], [[0, 1], [1, 2]])], 0), ([([3, 12], [[0, 1], [0]]), ([9, 4], [[0, 1], [0]])], 1), ], From 48de1872eca67163eae5117424d0868476df37b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 03:38:09 +0100 Subject: [PATCH 152/182] Add comment about strides --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 7a5a0f742..889445693 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -66,6 +66,8 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): self.physical = physical self.v_to_ps = v_to_ps + + # strides is of shape [v_ndim, p_ndim], such that v_index = strides @ p_index pshape = list(self.physical.shape) self.strides = tensor([strides_v2(pdims, pshape) for pdims in self.v_to_ps]) From b60a29da45a3df020e0a7d63dbba3b2c18990b25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 04:03:32 +0100 Subject: [PATCH 153/182] unsquash_pdim_from_strides --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 889445693..b2cd0385b 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -535,6 +535,17 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: return aten.cat.default([unwrap_to_dense(t) for t in tensors]) +def unsquash_pdim_from_strides( + physical: Tensor, 
pdim: int, new_pdim_shape: list[int] +) -> tuple[Tensor, Tensor]: + new_shape = list(physical.shape) + new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :] + new_physical = physical.reshape(new_shape) + + stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) + return new_physical, stride_multipliers + + def unsquash_pdim( physical: Tensor, pdim: int, new_pdim_shape: list[int] ) -> tuple[Tensor, list[list[int]]]: From 9ede7ecd01b22613fbec64c6bea54140be454382 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 05:20:47 +0100 Subject: [PATCH 154/182] Simplify test_concatenate params --- tests/unit/sparse/test_diagonal_sparse_tensor.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index f4642846f..e84a22d6f 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -262,11 +262,7 @@ def test_unsquash_pdim( ["dst_args", "dim"], [ ([([3], [[0], [0]]), ([3], [[0], [0]])], 1), - ([([3, 2], [[0], [0, 1]]), ([3], [[0], [0]])], 1), - ([([3], [[0], [0]]), ([3, 2], [[0], [0, 1]])], 1), - ([([3, 2], [[0], [0, 1]]), ([3, 2], [[0], [0, 1]])], 1), - ([([3, 4], [[0], [0, 1]]), ([3, 3, 4], [[0, 1], [1, 2]])], 0), - ([([3, 12], [[0, 1], [0]]), ([9, 4], [[0, 1], [0]])], 1), + ([([3, 2], [[0], [1, 0]]), ([3, 2], [[0], [1, 0]])], 1), ], ) def test_concatenate( From 5a25233c2499daff84cd1eaa84a025813693e81b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 05:21:27 +0100 Subject: [PATCH 155/182] Add basic implementation of cat_default for when all strides match and the pdim on which we concatenate already exists. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index b2cd0385b..05aa5998d 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -2,6 +2,7 @@ from functools import wraps from itertools import accumulate from math import prod +from typing import cast import torch from torch import Tensor, arange, meshgrid, stack, tensor, tensordot, zeros @@ -530,9 +531,31 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: print_fallback(aten.cat.default, (tensors, dim), {}) return aten.cat.default([unwrap_to_dense(t) for t in tensors]) - else: - # TODO: efficient implementation when all tensors are sparse - return aten.cat.default([unwrap_to_dense(t) for t in tensors]) + tensors_ = [cast(DiagonalSparseTensor, t) for t in tensors] + ref_tensor = tensors_[0] + ref_strides = ref_tensor.strides + if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]): + raise NotImplementedError() + + # We need to try to find the (pretty sure it either does not exist or is unique) physical + # dimension that makes us only move on virtual dimension dim. It also needs to be such that + # traversing it entirely brings us exactly to the end of virtual dimension dim. 
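+    # For instance, with physicals of shape [3, 2] and v_to_ps [[0], [1, 0]] (strides
+    # [[1, 0], [1, 3]]), concatenating on virtual dim 1 (of size 6) amounts to concatenating
+    # the physicals on physical dim 1: that dim only contributes to virtual dim 1, and its
+    # stride (3) times its size (2) spans that virtual dimension exactly.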
+ + ref_virtual_dim_size = ref_tensor.shape[dim] + indices = torch.argwhere( + torch.eq(ref_strides[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) + & torch.eq(ref_strides.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) + ) + assert len(indices) <= 1 + + if len(indices) == 0: + # TODO: create new physical dimension on which we'll concatenate + raise NotImplementedError() + + pdim = indices[0][0] + + new_physical = aten.cat.default([t.physical for t in tensors_], dim=pdim) + return DiagonalSparseTensor(new_physical, ref_tensor.v_to_ps) def unsquash_pdim_from_strides( From 5062b06c97d51e5d648c96ee0e043af2f42455fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 06:03:50 +0100 Subject: [PATCH 156/182] Add concat implementation for when the physical dimension on which to concatenate has to be added --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 05aa5998d..42d792d20 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -549,13 +549,28 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: assert len(indices) <= 1 if len(indices) == 0: - # TODO: create new physical dimension on which we'll concatenate - raise NotImplementedError() - - pdim = indices[0][0] + # Add a physical dimension pdim on which we can concatenate the physicals such that this + # translates into a concatenation of the virtuals on virtual dimension dim. + + # Stride-based representation: + # new_stride_column = torch.zeros(ref_tensor.ndim, dtype=torch.int) + # new_stride_column[dim] = ref_virtual_dim_size + + pdim = ref_tensor.physical.ndim + new_v_to_ps = [[d for d in pdims] for pdims in ref_tensor.v_to_ps] + new_v_to_ps[dim] = [pdim] + new_v_to_ps[dim] + new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps) + source = list(range(len(destination))) + physicals = [t.physical.unsqueeze(-1).movedim(source, destination) for t in tensors_] + else: + # Such a physical dimension already exists. Note that an alternative implementation would be + # to simply always add the physical dimension, and squash it if it ends up being not needed. 
+ physicals = [t.physical for t in tensors_] + pdim = indices[0][0] + new_v_to_ps = ref_tensor.v_to_ps - new_physical = aten.cat.default([t.physical for t in tensors_], dim=pdim) - return DiagonalSparseTensor(new_physical, ref_tensor.v_to_ps) + new_physical = aten.cat.default(physicals, dim=pdim) + return DiagonalSparseTensor(new_physical, new_v_to_ps) def unsquash_pdim_from_strides( From 3b9a9a89b73fc82bc6fee7a5b1108674f3ddec98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 15:11:37 +0100 Subject: [PATCH 157/182] Revamp grouping detection --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 106 +++++++++++++----- .../sparse/test_diagonal_sparse_tensor.py | 19 ++-- 2 files changed, 84 insertions(+), 41 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 42d792d20..acb0ffc4a 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -1,3 +1,4 @@ +import itertools import operator from functools import wraps from itertools import accumulate @@ -62,15 +63,14 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): f"v_to_ps elements are not encoded by first appearance. Found {v_to_ps}." ) - if any(len(group) != 1 for group in get_groupings(v_to_ps)): - raise ValueError(f"Dimensions must be maximally grouped. Found {v_to_ps}.") - self.physical = physical self.v_to_ps = v_to_ps # strides is of shape [v_ndim, p_ndim], such that v_index = strides @ p_index - pshape = list(self.physical.shape) - self.strides = tensor([strides_v2(pdims, pshape) for pdims in self.v_to_ps]) + self.strides = get_strides(list(self.physical.shape), v_to_ps) + + if any(len(group) != 1 for group in get_groupings_generalized(self.strides)): + raise ValueError(f"Dimensions must be maximally grouped. Found {v_to_ps}.") def to_dense( self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None @@ -188,11 +188,18 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]: return result +def get_strides(pshape: list[int], v_to_ps: list[list[int]]) -> Tensor: + strides = torch.tensor([strides_v2(pdims, pshape) for pdims in v_to_ps], dtype=torch.int64) + + # It's sometimes necessary to reshape: when v_to_ps contains 0 element for instance. + return strides.reshape(len(v_to_ps), len(pshape)) + + def argmax(iterable): return max(enumerate(iterable), key=lambda x: x[1])[0] -def strides_to_pdims(strides: list[int], physical_shape: list[int]) -> list[int]: +def strides_to_pdims(strides: Tensor, physical_shape: list[int]) -> list[int]: """ Given a list of strides, find and return the used physical dimensions. @@ -207,7 +214,7 @@ def strides_to_pdims(strides: list[int], physical_shape: list[int]) -> list[int] # e.g. strides = [22111, 201000], physical_shape = [10, 2] pdims = [] - remaining_strides = [s for s in strides] + remaining_strides = strides.clone() remaining_numel = ( sum(remaining_strides[i] * (physical_shape[i] - 1) for i in range(len(physical_shape))) + 1 ) @@ -253,29 +260,62 @@ def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]] return [res[i] for i in range(len(res))] -def get_groupings(v_to_ps: list[list[int]]) -> list[list[int]]: - """Example: [[0, 1, 2], [2, 0, 1], [2]] => [[0, 1], [2]]""" +def are_ratios_matching(v: Tensor) -> bool: + # Returns a boolean indicating whether all non-nan values in a vector are integer and equal to + # each other. 
+ # Returns a scalar boolean tensor indicating whether all values in v are the same or nan: + # [3.0, nan, 3.0] => True + # [nan, nan, nan] => True + # [3.0, nan, 2.0] => False + # [0.5, 0.5, 0.5] => False - mapping = dict[int, list[int]]() + non_nan_values = v[~v.isnan()] + return ( + torch.eq(non_nan_values.int(), non_nan_values).all().item() + and non_nan_values.eq(non_nan_values[0:1]).all().item() + ) - for p_dims in v_to_ps: - for i, p_dim in enumerate(p_dims): - if p_dim not in mapping: - mapping[p_dim] = p_dims[i:] - else: - mapping[p_dim] = longest_common_prefix(mapping[p_dim], p_dims[i:]) - groups = [] - visited_is = set() - for i, group in mapping.items(): - if i in visited_is: - continue +def get_groupings_generalized(strides: Tensor) -> list[list[int]]: + fstrides = strides.to(dtype=torch.float64) + # Note that float64 has 53 bits of precision, meaning that every integer number up to 2^53 can + # be represented on a float64 without any numerical error. Since strides are stored on int64, + # ratios can be of up to 2^64. This function may thus fail for stride values between 2^53 and + # 2^64. + + ratios = torch.div(fstrides.unsqueeze(2), fstrides.unsqueeze(1)) + + # Mapping from column id to the set of columns with which it can be grouped + groups = {i: {i} for i, column in enumerate(strides.T)} + for i1, i2 in itertools.permutations(range(strides.shape[1]), 2): + if are_ratios_matching(ratios[:, i1, i2]): + groups[i1].update(groups[i2]) + groups[i2].update(groups[i1]) + + new_columns = [] + for i, group in groups.items(): + sorted_group = sorted(list(group)) + if i == sorted_group[0]: # This ensures that the same group is added only once + new_columns.append(sorted_group) + + return new_columns - available_dims = set(group) - visited_is - groups.append(list(available_dims)) - visited_is.update(set(group)) - return groups +def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]: + strides_time_pshape = strides * tensor(pshape) + groups = {i: {i} for i, column in enumerate(strides.T)} + group_ids = [i for i in range(len(strides.T))] + for i1, i2 in itertools.combinations(range(strides.shape[1]), 2): + if torch.equal(strides[:, i1], strides_time_pshape[:, i2]): + groups[group_ids[i1]].update(groups[group_ids[i2]]) + group_ids[i2] = group_ids[i1] + + new_columns = [sorted(groups[group_id]) for group_id in sorted(set(group_ids))] + + if len(new_columns) != len(pshape): + print(f"Combined pshape with the following new columns: {new_columns}.") + + return new_columns def longest_common_prefix(l1: list[int], l2: list[int]) -> list[int]: @@ -413,12 +453,16 @@ def new_encoding(d: int) -> int: def fix_ungrouped_dims( physical: Tensor, v_to_ps: list[list[int]] ) -> tuple[Tensor, list[list[int]]]: - groups = get_groupings(v_to_ps) - physical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups]) - mapping = {group[0]: i for i, group in enumerate(groups)} - new_v_to_ps = [[mapping[i] for i in dims if i in mapping] for dims in v_to_ps] - - return physical, new_v_to_ps + strides = get_strides(list(physical.shape), v_to_ps) + groups = get_groupings(list(physical.shape), strides) + nphysical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups]) + stride_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64) + for j, group in enumerate(groups): + stride_mapping[group[-1], j] = 1 + + new_strides = strides @ stride_mapping + new_v_to_ps = [strides_to_pdims(stride, list(nphysical.shape)) for stride in 
new_strides] + return nphysical, new_v_to_ps def make_dst(physical: Tensor, v_to_ps: list[list[int]]) -> DiagonalSparseTensor: diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index e84a22d6f..e59f7d2c5 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -194,19 +194,18 @@ def test_encode_by_order( @mark.parametrize( - ["v_to_ps", "expected_groupings"], + ["pshape", "strides", "expected"], [ - ([[0, 1, 2], [2, 0, 1], [2]], [[0, 1], [2]]), - ([[0, 1, 0, 1]], [[0, 1]]), - ([[0, 1, 0, 1, 2]], [[0, 1], [2]]), - ([[0, 0]], [[0, 0]]), - ([[0, 1], [1, 2]], [[0], [1], [2]]), + ( + [[32, 2, 3, 4, 5]], + torch.tensor([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 60, 20, 5, 1]]), + [[0], [1, 2, 3, 4]], + ) ], ) -def test_get_groupings(v_to_ps: list[list[int]], expected_groupings: list[list[int]]): - groupings = get_groupings(v_to_ps) - - assert groupings == expected_groupings +def test_get_groupings(pshape: list[int], strides: torch.Tensor, expected: list[list[int]]): + result = get_groupings(pshape, strides) + assert result == expected @mark.parametrize( From ab15e1aede830b423d274891fd94c88b214014ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 17:35:46 +0100 Subject: [PATCH 158/182] Remove get_groupings_generalized * This function could be useful in the future if for example we want to merge two physical dimensions that have some overlap in the virtual dimension. --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 43 +------------------ 1 file changed, 1 insertion(+), 42 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index acb0ffc4a..457e6169b 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -69,7 +69,7 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): # strides is of shape [v_ndim, p_ndim], such that v_index = strides @ p_index self.strides = get_strides(list(self.physical.shape), v_to_ps) - if any(len(group) != 1 for group in get_groupings_generalized(self.strides)): + if any(len(group) != 1 for group in get_groupings(list(self.physical.shape), self.strides)): raise ValueError(f"Dimensions must be maximally grouped. Found {v_to_ps}.") def to_dense( @@ -260,47 +260,6 @@ def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]] return [res[i] for i in range(len(res))] -def are_ratios_matching(v: Tensor) -> bool: - # Returns a boolean indicating whether all non-nan values in a vector are integer and equal to - # each other. - # Returns a scalar boolean tensor indicating whether all values in v are the same or nan: - # [3.0, nan, 3.0] => True - # [nan, nan, nan] => True - # [3.0, nan, 2.0] => False - # [0.5, 0.5, 0.5] => False - - non_nan_values = v[~v.isnan()] - return ( - torch.eq(non_nan_values.int(), non_nan_values).all().item() - and non_nan_values.eq(non_nan_values[0:1]).all().item() - ) - - -def get_groupings_generalized(strides: Tensor) -> list[list[int]]: - fstrides = strides.to(dtype=torch.float64) - # Note that float64 has 53 bits of precision, meaning that every integer number up to 2^53 can - # be represented on a float64 without any numerical error. Since strides are stored on int64, - # ratios can be of up to 2^64. This function may thus fail for stride values between 2^53 and - # 2^64. 
- - ratios = torch.div(fstrides.unsqueeze(2), fstrides.unsqueeze(1)) - - # Mapping from column id to the set of columns with which it can be grouped - groups = {i: {i} for i, column in enumerate(strides.T)} - for i1, i2 in itertools.permutations(range(strides.shape[1]), 2): - if are_ratios_matching(ratios[:, i1, i2]): - groups[i1].update(groups[i2]) - groups[i2].update(groups[i1]) - - new_columns = [] - for i, group in groups.items(): - sorted_group = sorted(list(group)) - if i == sorted_group[0]: # This ensures that the same group is added only once - new_columns.append(sorted_group) - - return new_columns - - def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]: strides_time_pshape = strides * tensor(pshape) groups = {i: {i} for i, column in enumerate(strides.T)} From 26de009c5d33ff9e593cd711b4d740e99a3b3c5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 17:36:41 +0100 Subject: [PATCH 159/182] Remove longest_common_prefix --- src/torchjd/sparse/_diagonal_sparse_tensor.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index 457e6169b..a09798648 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -277,16 +277,6 @@ def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]: return new_columns -def longest_common_prefix(l1: list[int], l2: list[int]) -> list[int]: - prefix = [] - for a, b in zip(l1, l2, strict=False): - if a == b: - prefix.append(a) - else: - break - return prefix - - def encode_by_order(input: list[int]) -> tuple[list[int], list[int]]: """ Encodes values based on the order of their first appearance, starting at 0 and incrementing. From 2419c7e5a9849e396cecf5a6ca1457f76100fa3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 18:00:32 +0100 Subject: [PATCH 160/182] Restructure sparse package --- src/torchjd/sparse/__init__.py | 2 + .../_aten_function_overrides/__init__.py | 1 + .../_aten_function_overrides/backward.py | 36 + .../sparse/_aten_function_overrides/einsum.py | 246 ++++++ .../_aten_function_overrides/pointwise.py | 125 ++++ .../sparse/_aten_function_overrides/shape.py | 323 ++++++++ src/torchjd/sparse/_diagonal_sparse_tensor.py | 701 ------------------ .../sparse/test_diagonal_sparse_tensor.py | 8 +- 8 files changed, 738 insertions(+), 704 deletions(-) create mode 100644 src/torchjd/sparse/_aten_function_overrides/__init__.py create mode 100644 src/torchjd/sparse/_aten_function_overrides/backward.py create mode 100644 src/torchjd/sparse/_aten_function_overrides/einsum.py create mode 100644 src/torchjd/sparse/_aten_function_overrides/pointwise.py create mode 100644 src/torchjd/sparse/_aten_function_overrides/shape.py diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py index 8350adff9..c541cc462 100644 --- a/src/torchjd/sparse/__init__.py +++ b/src/torchjd/sparse/__init__.py @@ -1 +1,3 @@ +# Need to import this to execute the code inside and thus to override the functions +from . import _aten_function_overrides from ._diagonal_sparse_tensor import DiagonalSparseTensor, make_dst diff --git a/src/torchjd/sparse/_aten_function_overrides/__init__.py b/src/torchjd/sparse/_aten_function_overrides/__init__.py new file mode 100644 index 000000000..b33cf8d62 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/__init__.py @@ -0,0 +1 @@ +from . 
import backward, einsum, pointwise, shape diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py new file mode 100644 index 000000000..ed5d283ef --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -0,0 +1,36 @@ +from torch import Tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse import DiagonalSparseTensor + + +@DiagonalSparseTensor.implements(aten.threshold_backward.default) +def threshold_backward_default( + grad_output: DiagonalSparseTensor, self: Tensor, threshold +) -> DiagonalSparseTensor: + new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) + + return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) + + +@DiagonalSparseTensor.implements(aten.hardtanh_backward.default) +def hardtanh_backward_default( + grad_output: DiagonalSparseTensor, + self: Tensor, + min_val: Tensor | int | float, + max_val: Tensor | int | float, +) -> DiagonalSparseTensor: + if isinstance(self, DiagonalSparseTensor): + raise NotImplementedError() + + new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) + return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) + + +@DiagonalSparseTensor.implements(aten.hardswish_backward.default) +def hardswish_backward_default(grad_output: DiagonalSparseTensor, self: Tensor): + if isinstance(self, DiagonalSparseTensor): + raise NotImplementedError() + + new_physical = aten.hardswish_backward.default(grad_output.physical, self) + return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py new file mode 100644 index 000000000..0a6a2a318 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -0,0 +1,246 @@ +import torch +from torch import Tensor, tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse import DiagonalSparseTensor +from torchjd.sparse._diagonal_sparse_tensor import ( + p_to_vs_from_v_to_ps, + to_diagonal_sparse_tensor, + to_most_efficient_tensor, +) + + +def prepare_for_elementwise_op( + t1: Tensor | int | float, t2: Tensor | int | float +) -> tuple[DiagonalSparseTensor, DiagonalSparseTensor]: + """ + Prepares two DSTs of the same shape from two args, one of those being a DST, and the other being + a DST, Tensor, int or float. 
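+
+    For instance, given a DST of shape (3, 3) and the float 2.0, both returned values should be
+    DSTs of shape (3, 3), the second one wrapping the broadcast scalar (illustrative example).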
+ """ + + assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) + + if isinstance(t1, int) or isinstance(t1, float): + t1_ = tensor(t1, device=t2.device) + else: + t1_ = t1 + + if isinstance(t2, int) or isinstance(t2, float): + t2_ = tensor(t2, device=t1.device) + else: + t2_ = t2 + + t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) + t1_ = to_diagonal_sparse_tensor(t1_) + t2_ = to_diagonal_sparse_tensor(t2_) + + return t1_, t2_ + + +@DiagonalSparseTensor.implements(aten.mul.Tensor) +def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + # Element-wise multiplication with broadcasting + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@DiagonalSparseTensor.implements(aten.div.Tensor) +def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + t2_ = DiagonalSparseTensor(1.0 / t2_.physical, t2_.v_to_ps) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@DiagonalSparseTensor.implements(aten.mul.Scalar) +def mul_Scalar(t: DiagonalSparseTensor, scalar) -> DiagonalSparseTensor: + # TODO: maybe it could be that scalar is a scalar DST and t is a normal tensor. Need to check + # that + + assert isinstance(t, DiagonalSparseTensor) + new_physical = aten.mul.Scalar(t.physical, scalar) + return DiagonalSparseTensor(new_physical, t.v_to_ps) + + +@DiagonalSparseTensor.implements(aten.add.Tensor) +def add_Tensor( + t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 +) -> DiagonalSparseTensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + + if t1_.v_to_ps == t2_.v_to_ps: + new_physical = t1_.physical + t2_.physical * alpha + return DiagonalSparseTensor(new_physical, t1_.v_to_ps) + else: + raise NotImplementedError() + + +def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) -> Tensor: + + # First part of the algorithm, determine how to cluster physical indices as well as the common + # p_shapes corresponding to matching v_dims. Second part translates to physical einsum. + + # get a map from einsum index to (tensor_idx, v_dims) + # get a map from einsum index to merge of strides corresponding to v_dims with that index + # use to_target_physical_strides on each physical and v_to_ps + # cluster pairs of (einsum_index, new_stride) using new_v_to_ps and possibly its corresponding + # p_to_vs + # get unique indices + # map output indices (there can be splits) + # call physical einsum + # build resulting dst + + # OVER + + # an index in the physical einsum is uniquely characterized by a virtual einsum index and a + # stride corresponding to the physical stride in the virtual one (note that as the virtual shape + # for two virtual index that match should match, then we want to match the strides and reshape + # accordingly). + # We want to cluster such indices whenever several appear in the same p_to_vs + + # TODO: Handle ellipsis + # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list + # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. + # For this reason, an index is decomposed into sub-indices that are then independently + # clustered. + # So if an index i in args for some DiagonalSparseTensor corresponds to a v_to_ps [j, k, l], + # We will consider three indices (i, 0), (i, 1) and (i, 2). 
+    # If furthermore [k] corresponds to the v_to_ps of some other tensor with index j, then
+    # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same index in
+    # the resulting einsum).
+    # Note that this is a problem if two virtual dimensions (from possibly different
+    # DiagonalSparseTensors) have the same size but not the same decomposition into physical
+    # dimension sizes. For now, let's leave that responsibility to the calling functions; if we
+    # can factor code later on, we will.
+
+    index_parents = dict[tuple[int, int], tuple[int, int]]()
+
+    def get_representative(index: tuple[int, int]) -> tuple[int, int]:
+        if index not in index_parents:
+            # If an index is not yet in a cluster, put it in its own.
+            index_parents[index] = index
+        current = index_parents[index]
+        if current != index:
+            # Compress path to representative
+            index_parents[index] = get_representative(current)
+        return index_parents[index]
+
+    def group_indices(indices: list[tuple[int, int]]) -> None:
+        first_representative = get_representative(indices[0])
+        for i in indices[1:]:
+            curr_representative = get_representative(i)
+            index_parents[curr_representative] = first_representative
+
+    new_indices_pair = list[list[tuple[int, int]]]()
+    tensors = list[Tensor]()
+    indices_to_n_pdims = dict[int, int]()
+    for t, indices in args:
+        assert isinstance(t, DiagonalSparseTensor)
+        tensors.append(t.physical)
+        for ps, index in zip(t.v_to_ps, indices):
+            if index in indices_to_n_pdims:
+                assert indices_to_n_pdims[index] == len(ps)
+            else:
+                indices_to_n_pdims[index] = len(ps)
+        p_to_vs = p_to_vs_from_v_to_ps(t.v_to_ps)
+        for indices_ in p_to_vs:
+            # Elements in indices[indices_] map to the same dimension, so they should be
+            # clustered together.
+            group_indices([(indices[i], sub_i) for i, sub_i in indices_])
+        # Record the physical dimensions: indices[v] for v in vs will end up mapping to the same
+        # final dimension as they were just clustered, so we can take the first, which exists
+        # since t is a valid DST.
+        new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs])
+
+    current = 0
+    pair_to_int = dict[tuple[int, int], int]()
+
+    def unique_int(pair: tuple[int, int]) -> int:
+        nonlocal current
+        if pair in pair_to_int:
+            return pair_to_int[pair]
+        pair_to_int[pair] = current
+        current += 1
+        return pair_to_int[pair]
+
+    new_indices = [
+        [unique_int(get_representative(i)) for i in indices] for indices in new_indices_pair
+    ]
+    new_output = list[int]()
+    v_to_ps = list[list[int]]()
+    for i in output:
+        current_v_to_ps = []
+        for j in range(indices_to_n_pdims[i]):
+            k = unique_int(get_representative((i, j)))
+            if k in new_output:
+                current_v_to_ps.append(new_output.index(k))
+            else:
+                current_v_to_ps.append(len(new_output))
+                new_output.append(k)
+        v_to_ps.append(current_v_to_ps)
+
+    physical = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output)
+    # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped.
+    # Maybe there is a way to fix that though.
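+    # Illustration: for a diagonal matrix D = DiagonalSparseTensor(d, [[0], [0]]) and a 1-d DST v,
+    # einsum((D, [0, 1]), (v, [1]), output=[0]) clusters the sub-indices (0, 0) and (1, 0) into a
+    # single physical index, so the physical einsum should reduce to d * v.physical, i.e. the
+    # matrix-vector product diag(d) @ v.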
+ return to_most_efficient_tensor(physical, v_to_ps) + + +@DiagonalSparseTensor.implements(aten.bmm.default) +def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) + assert ( + mat1.ndim == 3 + and mat2.ndim == 3 + and mat1.shape[0] == mat2.shape[0] + and mat1.shape[2] == mat2.shape[1] + ) + + mat1_ = to_diagonal_sparse_tensor(mat1) + mat2_ = to_diagonal_sparse_tensor(mat2) + + # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes + # decompositions. If not, can reshape to common decomposition? + return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3]) + + +@DiagonalSparseTensor.implements(aten.mm.default) +def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) + assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] + + mat1_ = to_diagonal_sparse_tensor(mat1) + mat2_ = to_diagonal_sparse_tensor(mat2) + + return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) + + +@DiagonalSparseTensor.implements(aten.mean.default) +def mean_default(t: DiagonalSparseTensor) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + return aten.sum.default(t.physical) / t.numel() + + +@DiagonalSparseTensor.implements(aten.sum.default) +def sum_default(t: DiagonalSparseTensor) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + return aten.sum.default(t.physical) + + +@DiagonalSparseTensor.implements(aten.sum.dim_IntList) +def sum_dim_IntList( + t: DiagonalSparseTensor, dim: list[int], keepdim: bool = False, dtype=None +) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + if dtype: + raise NotImplementedError() + + all_dims = list(range(t.ndim)) + result = einsum((t, all_dims), output=[d for d in all_dims if d not in dim]) + + if keepdim: + for d in dim: + result = result.unsqueeze(d) + + return result diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py new file mode 100644 index 000000000..8e0e89f96 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -0,0 +1,125 @@ +from torch.ops import aten # type: ignore + +from torchjd.sparse import DiagonalSparseTensor + +# pointwise functions applied to one Tensor with `0.0 → 0` +_POINTWISE_FUNCTIONS = [ + aten.abs.default, + aten.absolute.default, + aten.asin.default, + aten.asinh.default, + aten.atan.default, + aten.atanh.default, + aten.ceil.default, + aten.erf.default, + aten.erfinv.default, + aten.expm1.default, + aten.fix.default, + aten.floor.default, + aten.hardtanh.default, + aten.leaky_relu.default, + aten.log1p.default, + aten.neg.default, + aten.negative.default, + aten.positive.default, + aten.relu.default, + aten.round.default, + aten.sgn.default, + aten.sign.default, + aten.sin.default, + aten.sinh.default, + aten.sqrt.default, + aten.square.default, + aten.tan.default, + aten.tanh.default, + aten.trunc.default, +] + +_IN_PLACE_POINTWISE_FUNCTIONS = [ + aten.abs_.default, + aten.absolute_.default, + aten.asin_.default, + aten.asinh_.default, + aten.atan_.default, + aten.atanh_.default, + aten.ceil_.default, + aten.erf_.default, + aten.erfinv_.default, + aten.expm1_.default, + aten.fix_.default, + aten.floor_.default, + aten.hardtanh_.default, + aten.leaky_relu_.default, + aten.log1p_.default, + aten.neg_.default, + aten.negative_.default, + aten.relu_.default, + 
aten.round_.default, + aten.sgn_.default, + aten.sign_.default, + aten.sin_.default, + aten.sinh_.default, + aten.sqrt_.default, + aten.square_.default, + aten.tan_.default, + aten.tanh_.default, + aten.trunc_.default, +] + + +def _override_pointwise(op): + @DiagonalSparseTensor.implements(op) + def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + return DiagonalSparseTensor(op(t.physical), t.v_to_ps) + + return func_ + + +def _override_inplace_pointwise(op): + @DiagonalSparseTensor.implements(op) + def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + op(t.physical) + return t + + +for pointwise_func in _POINTWISE_FUNCTIONS: + _override_pointwise(pointwise_func) + +for pointwise_func in _IN_PLACE_POINTWISE_FUNCTIONS: + _override_inplace_pointwise(pointwise_func) + + +@DiagonalSparseTensor.implements(aten.pow.Tensor_Scalar) +def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 + return aten.pow.Tensor_Scalar(t.to_dense(), exponent) + + new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) + return DiagonalSparseTensor(new_physical, t.v_to_ps) + + +# Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. +@DiagonalSparseTensor.implements(aten.pow_.Scalar) +def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 + # Note sure if it's even possible to densify in-place, so let's just raise an error. + raise ValueError(f"in-place pow with an exponent of {exponent} (<= 0) is not supported.") + + aten.pow_.Scalar(t.physical, exponent) + return t + + +@DiagonalSparseTensor.implements(aten.div.Scalar) +def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + + new_physical = aten.div.Scalar(t.physical, divisor) + return DiagonalSparseTensor(new_physical, t.v_to_ps) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py new file mode 100644 index 000000000..7e9939770 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -0,0 +1,323 @@ +from math import prod +from typing import cast + +import torch +from torch import Tensor, tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse import DiagonalSparseTensor +from torchjd.sparse._diagonal_sparse_tensor import ( + encode_v_to_ps, + fix_dim_encoding, + print_fallback, + to_most_efficient_tensor, + unwrap_to_dense, +) + + +@DiagonalSparseTensor.implements(aten.view.default) +def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + shape = infer_shape(shape, t.numel()) + + if prod(shape) != t.numel(): + raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") + + new_v_to_ps = [] + idx = 0 + flat_v_to_ps = [dim for dims in t.v_to_ps for dim in dims] + new_physical = t.physical + for s in shape: + group = [] + current_size = 1 + + while current_size < s: + if idx >= len(flat_v_to_ps): + # TODO: I don't think this can happen, need to review and remove if I'm right. 
+ raise ValueError() + + pdim = flat_v_to_ps[idx] + pdim_size = new_physical.shape[pdim] + + if current_size * pdim_size > s: + # Need to split physical dimension + if s % current_size != 0: + raise ValueError("Can't split physical dimension") + + new_pdim_first_dim_size = s // current_size + + if pdim_size % new_pdim_first_dim_size != 0: + raise ValueError("Can't split physical dimension") + + new_pdim_shape = [new_pdim_first_dim_size, pdim_size // new_pdim_first_dim_size] + new_physical, new_encoding = unsquash_pdim(new_physical, pdim, new_pdim_shape) + + new_v_to_ps = [ + [new_d for d in dims for new_d in new_encoding[d]] for dims in new_v_to_ps + ] + # A bit of a weird trick here. We want to re-encode flat_v_to_ps according to + # new_encoding. However, re-encoding elements before idx would potentially change + # the length of the list before idx, so idx would not have the right value anymore. + # Since we don't need the elements of flat_v_to_ps that are before idx anyway, we + # just get rid of them and re-encode flat_v_to_ps[idx:] instead, and reset idx to 0 + # to say that we're back at the beginning of this new list. + flat_v_to_ps = [new_d for d in flat_v_to_ps[idx:] for new_d in new_encoding[d]] + idx = 0 + + group.append(pdim) + current_size *= new_physical.shape[pdim] + idx += 1 + + new_v_to_ps.append(group) + + if idx != len(flat_v_to_ps): + raise ValueError(f"idx != len(flat_v_to_ps). {idx}; {flat_v_to_ps}; {shape}; {t.v_to_ps}") + + # The above code does not handle physical dimension squashing, so the physical is not + # necessarily maximally squashed at this point, so we need the safe constructor. + return to_most_efficient_tensor(new_physical, new_v_to_ps) + + +def infer_shape(shape: list[int], numel: int) -> list[int]: + if shape.count(-1) > 1: + raise ValueError("Only one dimension can be inferred") + known = 1 + for s in shape: + if s != -1: + known *= s + inferred = numel // known + return [inferred if s == -1 else s for s in shape] + + +def unsquash_pdim_from_strides( + physical: Tensor, pdim: int, new_pdim_shape: list[int] +) -> tuple[Tensor, Tensor]: + new_shape = list(physical.shape) + new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :] + new_physical = physical.reshape(new_shape) + + stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) + return new_physical, stride_multipliers + + +def unsquash_pdim( + physical: Tensor, pdim: int, new_pdim_shape: list[int] +) -> tuple[Tensor, list[list[int]]]: + new_shape = list(physical.shape) + new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :] + new_physical = physical.reshape(new_shape) + + def new_encoding_fn(d: int) -> list[int]: + if d < pdim: + return [d] + elif d > pdim: + return [d + len(new_pdim_shape) - 1] + else: + return [pdim + i for i in range(len(new_pdim_shape))] + + new_encoding = [new_encoding_fn(d) for d in range(len(physical.shape))] + return new_physical, new_encoding + + +@DiagonalSparseTensor.implements(aten._unsafe_view.default) +def _unsafe_view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: + return view_default( + t, shape + ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp + + +@DiagonalSparseTensor.implements(aten.unsqueeze.default) +def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + assert -t.ndim - 1 <= dim < t.ndim + 1 + + if dim < 0: + dim = 
t.ndim + dim + 1 + + new_v_to_ps = [p for p in t.v_to_ps] # Deepcopy the list to not modify the original v_to_ps + new_v_to_ps.insert(dim, []) + + return DiagonalSparseTensor(t.physical, new_v_to_ps) + + +@DiagonalSparseTensor.implements(aten.squeeze.dims) +def squeeze_dims(t: DiagonalSparseTensor, dims: list[int] | int | None) -> Tensor: + assert isinstance(t, DiagonalSparseTensor) + + if dims is None: + excluded = set(range(t.ndim)) + elif isinstance(dims, int): + excluded = {dims} + else: + excluded = set(dims) + + new_v_to_ps = [pdims for i, pdims in enumerate(t.v_to_ps) if i not in excluded] + + return to_most_efficient_tensor(t.physical, new_v_to_ps) + + +@DiagonalSparseTensor.implements(aten.permute.default) +def permute_default(t: DiagonalSparseTensor, dims: list[int]) -> DiagonalSparseTensor: + new_v_to_ps = [t.v_to_ps[d] for d in dims] + + new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) + return DiagonalSparseTensor(new_physical, new_v_to_ps) + + +@DiagonalSparseTensor.implements(aten.cat.default) +def cat_default(tensors: list[Tensor], dim: int) -> Tensor: + if any(not isinstance(t, DiagonalSparseTensor) for t in tensors): + print_fallback(aten.cat.default, (tensors, dim), {}) + return aten.cat.default([unwrap_to_dense(t) for t in tensors]) + + tensors_ = [cast(DiagonalSparseTensor, t) for t in tensors] + ref_tensor = tensors_[0] + ref_strides = ref_tensor.strides + if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]): + raise NotImplementedError() + + # We need to try to find the (pretty sure it either does not exist or is unique) physical + # dimension that makes us only move on virtual dimension dim. It also needs to be such that + # traversing it entirely brings us exactly to the end of virtual dimension dim. + + ref_virtual_dim_size = ref_tensor.shape[dim] + indices = torch.argwhere( + torch.eq(ref_strides[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) + & torch.eq(ref_strides.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) + ) + assert len(indices) <= 1 + + if len(indices) == 0: + # Add a physical dimension pdim on which we can concatenate the physicals such that this + # translates into a concatenation of the virtuals on virtual dimension dim. + + # Stride-based representation: + # new_stride_column = torch.zeros(ref_tensor.ndim, dtype=torch.int) + # new_stride_column[dim] = ref_virtual_dim_size + + pdim = ref_tensor.physical.ndim + new_v_to_ps = [[d for d in pdims] for pdims in ref_tensor.v_to_ps] + new_v_to_ps[dim] = [pdim] + new_v_to_ps[dim] + new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps) + source = list(range(len(destination))) + physicals = [t.physical.unsqueeze(-1).movedim(source, destination) for t in tensors_] + else: + # Such a physical dimension already exists. Note that an alternative implementation would be + # to simply always add the physical dimension, and squash it if it ends up being not needed. 
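+        # Illustration: with v_to_ps == [[0], [1]] (a dense 2-d tensor stored as-is), concatenating
+        # along dim 0 should amount to concatenating the physicals along physical dim 0.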
+ physicals = [t.physical for t in tensors_] + pdim = indices[0][0] + new_v_to_ps = ref_tensor.v_to_ps + + new_physical = aten.cat.default(physicals, dim=pdim) + return DiagonalSparseTensor(new_physical, new_v_to_ps) + + +@DiagonalSparseTensor.implements(aten.expand.default) +def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseTensor: + # note that sizes could also be just an int, or a torch.Size i think + assert isinstance(t, DiagonalSparseTensor) + assert isinstance(sizes, list) + assert len(sizes) >= t.ndim + + for _ in range(len(sizes) - t.ndim): + t = t.unsqueeze(0) + + assert len(sizes) == t.ndim + + new_physical = t.physical + new_v_to_ps = t.v_to_ps + n_added_physical_dims = 0 + for dim, (ps, orig_size, new_size) in enumerate(zip(t.v_to_ps, t.shape, sizes, strict=True)): + if len(ps) > 0 and orig_size != new_size and new_size != -1: + raise ValueError( + f"Cannot expand dim {dim} of size != 1. Found size {orig_size} and target size " + f"{new_size}." + ) + + if len(ps) == 0 and new_size != 1 and new_size != -1: + # Add a dimension of size new_size at the end of the physical tensor. + new_physical_shape = list(new_physical.shape) + [new_size] + new_physical = new_physical.unsqueeze(-1).expand(new_physical_shape) + new_v_to_ps[dim] = [t.physical.ndim + n_added_physical_dims] + n_added_physical_dims += 1 + + new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps) + new_physical = new_physical.movedim(list(range(len(destination))), destination) + + return DiagonalSparseTensor(new_physical, new_v_to_ps) + + +@DiagonalSparseTensor.implements(aten.broadcast_tensors.default) +def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: + if len(tensors) != 2: + raise NotImplementedError() + + t1, t2 = tensors + + if t1.shape == t2.shape: + return t1, t2 + + a = t1 if t1.ndim >= t2.ndim else t2 + b = t2 if t1.ndim >= t2.ndim else t1 + + a_shape = list(a.shape) + padded_b_shape = [1] * (a.ndim - b.ndim) + list(b.shape) + + new_shape = list[int]() + + for s_a, s_b in zip(a_shape, padded_b_shape): + if s_a != 1 and s_b != 1 and s_a != s_b: + raise ValueError("Incompatible shapes for broadcasting") + else: + new_shape.append(max(s_a, s_b)) + + return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape) + + +@DiagonalSparseTensor.implements(aten.slice.Tensor) +def slice_Tensor( + t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 +) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + + physical_dims = t.v_to_ps[dim] + + if len(physical_dims) > 1: + raise ValueError( + "Cannot yet slice virtual dim corresponding to several physical dims.\n" + f"{t.debug_info()}\n" + f"dim={dim}, start={start}, end={end}, step={step}." + ) + elif len(physical_dims) == 0: + # Trying to slice a virtual dim of size 1. + # Either + # - the element of this dim is included in the slice: keep it as it is + # - it's not included in the slice (e.g. end<=start): we would end up with a size of 0 on + # that dimension, so we'd need to add a dimension of size 0 to the physical. This is not + # implemented yet. + start_ = start if start is not None else 0 + end_ = end if end is not None else 1 + if end_ <= start_: # TODO: the condition might be a bit more complex if step != 1 + raise NotImplementedError( + "Slicing of dimension of size 1 leading to dimension of size 0 not implemented yet." 
+ ) + else: + new_physical = t.physical + else: + physical_dim = physical_dims[0] + new_physical = aten.slice.Tensor(t.physical, physical_dim, start, end, step) + + return DiagonalSparseTensor(new_physical, t.v_to_ps) + + +@DiagonalSparseTensor.implements(aten.transpose.int) +def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSparseTensor: + assert isinstance(t, DiagonalSparseTensor) + + new_v_to_ps = [dims for dims in t.v_to_ps] + new_v_to_ps[dim0] = t.v_to_ps[dim1] + new_v_to_ps[dim1] = t.v_to_ps[dim0] + + new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) + return DiagonalSparseTensor(new_physical, new_v_to_ps) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_diagonal_sparse_tensor.py index a09798648..a051e5c13 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_diagonal_sparse_tensor.py @@ -3,11 +3,9 @@ from functools import wraps from itertools import accumulate from math import prod -from typing import cast import torch from torch import Tensor, arange, meshgrid, stack, tensor, tensordot, zeros -from torch.ops import aten # type: ignore from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -421,702 +419,3 @@ def make_dst(physical: Tensor, v_to_ps: list[list[int]]) -> DiagonalSparseTensor physical, v_to_ps = fix_dim_of_size_1(physical, v_to_ps) physical, v_to_ps = fix_ungrouped_dims(physical, v_to_ps) return DiagonalSparseTensor(physical, v_to_ps) - - -@DiagonalSparseTensor.implements(aten.mean.default) -def mean_default(t: DiagonalSparseTensor) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t.physical) / t.numel() - - -@DiagonalSparseTensor.implements(aten.sum.default) -def sum_default(t: DiagonalSparseTensor) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) - return aten.sum.default(t.physical) - - -@DiagonalSparseTensor.implements(aten.sum.dim_IntList) -def sum_dim_IntList( - t: DiagonalSparseTensor, dim: list[int], keepdim: bool = False, dtype=None -) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) - - if dtype: - raise NotImplementedError() - - all_dims = list(range(t.ndim)) - result = einsum((t, all_dims), output=[d for d in all_dims if d not in dim]) - - if keepdim: - for d in dim: - result = result.unsqueeze(d) - - return result - - -@DiagonalSparseTensor.implements(aten.pow.Tensor_Scalar) -def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - - if exponent <= 0.0: - # Need to densify because we don't have pow(0.0, exponent) = 0.0 - return aten.pow.Tensor_Scalar(t.to_dense(), exponent) - - new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) - return DiagonalSparseTensor(new_physical, t.v_to_ps) - - -# Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. -@DiagonalSparseTensor.implements(aten.pow_.Scalar) -def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - - if exponent <= 0.0: - # Need to densify because we don't have pow(0.0, exponent) = 0.0 - # Note sure if it's even possible to densify in-place, so let's just raise an error. 
- raise ValueError(f"in-place pow with an exponent of {exponent} (<= 0) is not supported.") - - aten.pow_.Scalar(t.physical, exponent) - return t - - -@DiagonalSparseTensor.implements(aten.unsqueeze.default) -def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - assert -t.ndim - 1 <= dim < t.ndim + 1 - - if dim < 0: - dim = t.ndim + dim + 1 - - new_v_to_ps = [p for p in t.v_to_ps] # Deepcopy the list to not modify the original v_to_ps - new_v_to_ps.insert(dim, []) - - return DiagonalSparseTensor(t.physical, new_v_to_ps) - - -@DiagonalSparseTensor.implements(aten.squeeze.dims) -def squeeze_dims(t: DiagonalSparseTensor, dims: list[int] | int | None) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) - - if dims is None: - excluded = set(range(t.ndim)) - elif isinstance(dims, int): - excluded = {dims} - else: - excluded = set(dims) - - new_v_to_ps = [pdims for i, pdims in enumerate(t.v_to_ps) if i not in excluded] - - return to_most_efficient_tensor(t.physical, new_v_to_ps) - - -@DiagonalSparseTensor.implements(aten.permute.default) -def permute_default(t: DiagonalSparseTensor, dims: list[int]) -> DiagonalSparseTensor: - new_v_to_ps = [t.v_to_ps[d] for d in dims] - - new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) - return DiagonalSparseTensor(new_physical, new_v_to_ps) - - -@DiagonalSparseTensor.implements(aten.cat.default) -def cat_default(tensors: list[Tensor], dim: int) -> Tensor: - if any(not isinstance(t, DiagonalSparseTensor) for t in tensors): - print_fallback(aten.cat.default, (tensors, dim), {}) - return aten.cat.default([unwrap_to_dense(t) for t in tensors]) - - tensors_ = [cast(DiagonalSparseTensor, t) for t in tensors] - ref_tensor = tensors_[0] - ref_strides = ref_tensor.strides - if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]): - raise NotImplementedError() - - # We need to try to find the (pretty sure it either does not exist or is unique) physical - # dimension that makes us only move on virtual dimension dim. It also needs to be such that - # traversing it entirely brings us exactly to the end of virtual dimension dim. - - ref_virtual_dim_size = ref_tensor.shape[dim] - indices = torch.argwhere( - torch.eq(ref_strides[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) - & torch.eq(ref_strides.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) - ) - assert len(indices) <= 1 - - if len(indices) == 0: - # Add a physical dimension pdim on which we can concatenate the physicals such that this - # translates into a concatenation of the virtuals on virtual dimension dim. - - # Stride-based representation: - # new_stride_column = torch.zeros(ref_tensor.ndim, dtype=torch.int) - # new_stride_column[dim] = ref_virtual_dim_size - - pdim = ref_tensor.physical.ndim - new_v_to_ps = [[d for d in pdims] for pdims in ref_tensor.v_to_ps] - new_v_to_ps[dim] = [pdim] + new_v_to_ps[dim] - new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps) - source = list(range(len(destination))) - physicals = [t.physical.unsqueeze(-1).movedim(source, destination) for t in tensors_] - else: - # Such a physical dimension already exists. Note that an alternative implementation would be - # to simply always add the physical dimension, and squash it if it ends up being not needed. 
- physicals = [t.physical for t in tensors_] - pdim = indices[0][0] - new_v_to_ps = ref_tensor.v_to_ps - - new_physical = aten.cat.default(physicals, dim=pdim) - return DiagonalSparseTensor(new_physical, new_v_to_ps) - - -def unsquash_pdim_from_strides( - physical: Tensor, pdim: int, new_pdim_shape: list[int] -) -> tuple[Tensor, Tensor]: - new_shape = list(physical.shape) - new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :] - new_physical = physical.reshape(new_shape) - - stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) - return new_physical, stride_multipliers - - -def unsquash_pdim( - physical: Tensor, pdim: int, new_pdim_shape: list[int] -) -> tuple[Tensor, list[list[int]]]: - new_shape = list(physical.shape) - new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :] - new_physical = physical.reshape(new_shape) - - def new_encoding_fn(d: int) -> list[int]: - if d < pdim: - return [d] - elif d > pdim: - return [d + len(new_pdim_shape) - 1] - else: - return [pdim + i for i in range(len(new_pdim_shape))] - - new_encoding = [new_encoding_fn(d) for d in range(len(physical.shape))] - return new_physical, new_encoding - - -def infer_shape(shape: list[int], numel: int) -> list[int]: - if shape.count(-1) > 1: - raise ValueError("Only one dimension can be inferred") - known = 1 - for s in shape: - if s != -1: - known *= s - inferred = numel // known - return [inferred if s == -1 else s for s in shape] - - -@DiagonalSparseTensor.implements(aten.view.default) -def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) - - shape = infer_shape(shape, t.numel()) - - if prod(shape) != t.numel(): - raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") - - new_v_to_ps = [] - idx = 0 - flat_v_to_ps = [dim for dims in t.v_to_ps for dim in dims] - new_physical = t.physical - for s in shape: - group = [] - current_size = 1 - - while current_size < s: - if idx >= len(flat_v_to_ps): - # TODO: I don't think this can happen, need to review and remove if I'm right. - raise ValueError() - - pdim = flat_v_to_ps[idx] - pdim_size = new_physical.shape[pdim] - - if current_size * pdim_size > s: - # Need to split physical dimension - if s % current_size != 0: - raise ValueError("Can't split physical dimension") - - new_pdim_first_dim_size = s // current_size - - if pdim_size % new_pdim_first_dim_size != 0: - raise ValueError("Can't split physical dimension") - - new_pdim_shape = [new_pdim_first_dim_size, pdim_size // new_pdim_first_dim_size] - new_physical, new_encoding = unsquash_pdim(new_physical, pdim, new_pdim_shape) - - new_v_to_ps = [ - [new_d for d in dims for new_d in new_encoding[d]] for dims in new_v_to_ps - ] - # A bit of a weird trick here. We want to re-encode flat_v_to_ps according to - # new_encoding. However, re-encoding elements before idx would potentially change - # the length of the list before idx, so idx would not have the right value anymore. - # Since we don't need the elements of flat_v_to_ps that are before idx anyway, we - # just get rid of them and re-encode flat_v_to_ps[idx:] instead, and reset idx to 0 - # to say that we're back at the beginning of this new list. 
- flat_v_to_ps = [new_d for d in flat_v_to_ps[idx:] for new_d in new_encoding[d]] - idx = 0 - - group.append(pdim) - current_size *= new_physical.shape[pdim] - idx += 1 - - new_v_to_ps.append(group) - - if idx != len(flat_v_to_ps): - raise ValueError(f"idx != len(flat_v_to_ps). {idx}; {flat_v_to_ps}; {shape}; {t.v_to_ps}") - - # The above code does not handle physical dimension squashing, so the physical is not - # necessarily maximally squashed at this point, so we need the safe constructor. - return to_most_efficient_tensor(new_physical, new_v_to_ps) - - -@DiagonalSparseTensor.implements(aten._unsafe_view.default) -def _unsafe_view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: - return view_default( - t, shape - ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp - - -@DiagonalSparseTensor.implements(aten.expand.default) -def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseTensor: - # note that sizes could also be just an int, or a torch.Size i think - assert isinstance(t, DiagonalSparseTensor) - assert isinstance(sizes, list) - assert len(sizes) >= t.ndim - - for _ in range(len(sizes) - t.ndim): - t = t.unsqueeze(0) - - assert len(sizes) == t.ndim - - new_physical = t.physical - new_v_to_ps = t.v_to_ps - n_added_physical_dims = 0 - for dim, (ps, orig_size, new_size) in enumerate(zip(t.v_to_ps, t.shape, sizes, strict=True)): - if len(ps) > 0 and orig_size != new_size and new_size != -1: - raise ValueError( - f"Cannot expand dim {dim} of size != 1. Found size {orig_size} and target size " - f"{new_size}." - ) - - if len(ps) == 0 and new_size != 1 and new_size != -1: - # Add a dimension of size new_size at the end of the physical tensor. 
- new_physical_shape = list(new_physical.shape) + [new_size] - new_physical = new_physical.unsqueeze(-1).expand(new_physical_shape) - new_v_to_ps[dim] = [t.physical.ndim + n_added_physical_dims] - n_added_physical_dims += 1 - - new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps) - new_physical = new_physical.movedim(list(range(len(destination))), destination) - - return DiagonalSparseTensor(new_physical, new_v_to_ps) - - -@DiagonalSparseTensor.implements(aten.broadcast_tensors.default) -def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: - if len(tensors) != 2: - raise NotImplementedError() - - t1, t2 = tensors - - if t1.shape == t2.shape: - return t1, t2 - - a = t1 if t1.ndim >= t2.ndim else t2 - b = t2 if t1.ndim >= t2.ndim else t1 - - a_shape = list(a.shape) - padded_b_shape = [1] * (a.ndim - b.ndim) + list(b.shape) - - new_shape = list[int]() - - for s_a, s_b in zip(a_shape, padded_b_shape): - if s_a != 1 and s_b != 1 and s_a != s_b: - raise ValueError("Incompatible shapes for broadcasting") - else: - new_shape.append(max(s_a, s_b)) - - return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape) - - -@DiagonalSparseTensor.implements(aten.div.Scalar) -def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - - new_physical = aten.div.Scalar(t.physical, divisor) - return DiagonalSparseTensor(new_physical, t.v_to_ps) - - -@DiagonalSparseTensor.implements(aten.threshold_backward.default) -def threshold_backward_default( - grad_output: DiagonalSparseTensor, self: Tensor, threshold -) -> DiagonalSparseTensor: - new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) - - return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) - - -@DiagonalSparseTensor.implements(aten.hardtanh_backward.default) -def hardtanh_backward_default( - grad_output: DiagonalSparseTensor, - self: Tensor, - min_val: Tensor | int | float, - max_val: Tensor | int | float, -) -> DiagonalSparseTensor: - if isinstance(self, DiagonalSparseTensor): - raise NotImplementedError() - - new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) - return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) - - -@DiagonalSparseTensor.implements(aten.hardswish_backward.default) -def hardswish_backward_default(grad_output: DiagonalSparseTensor, self: Tensor): - if isinstance(self, DiagonalSparseTensor): - raise NotImplementedError() - - new_physical = aten.hardswish_backward.default(grad_output.physical, self) - return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) - - -@DiagonalSparseTensor.implements(aten.slice.Tensor) -def slice_Tensor( - t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 -) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - - physical_dims = t.v_to_ps[dim] - - if len(physical_dims) > 1: - raise ValueError( - "Cannot yet slice virtual dim corresponding to several physical dims.\n" - f"{t.debug_info()}\n" - f"dim={dim}, start={start}, end={end}, step={step}." - ) - elif len(physical_dims) == 0: - # Trying to slice a virtual dim of size 1. - # Either - # - the element of this dim is included in the slice: keep it as it is - # - it's not included in the slice (e.g. end<=start): we would end up with a size of 0 on - # that dimension, so we'd need to add a dimension of size 0 to the physical. This is not - # implemented yet. 
- start_ = start if start is not None else 0 - end_ = end if end is not None else 1 - if end_ <= start_: # TODO: the condition might be a bit more complex if step != 1 - raise NotImplementedError( - "Slicing of dimension of size 1 leading to dimension of size 0 not implemented yet." - ) - else: - new_physical = t.physical - else: - physical_dim = physical_dims[0] - new_physical = aten.slice.Tensor(t.physical, physical_dim, start, end, step) - - return DiagonalSparseTensor(new_physical, t.v_to_ps) - - -def prepare_for_elementwise_op( - t1: Tensor | int | float, t2: Tensor | int | float -) -> tuple[DiagonalSparseTensor, DiagonalSparseTensor]: - """ - Prepares two DSTs of the same shape from two args, one of those being a DST, and the other being - a DST, Tensor, int or float. - """ - - assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) - - if isinstance(t1, int) or isinstance(t1, float): - t1_ = tensor(t1, device=t2.device) - else: - t1_ = t1 - - if isinstance(t2, int) or isinstance(t2, float): - t2_ = tensor(t2, device=t1.device) - else: - t2_ = t2 - - t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) - t1_ = to_diagonal_sparse_tensor(t1_) - t2_ = to_diagonal_sparse_tensor(t2_) - - return t1_, t2_ - - -@DiagonalSparseTensor.implements(aten.mul.Tensor) -def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: - # Element-wise multiplication with broadcasting - t1_, t2_ = prepare_for_elementwise_op(t1, t2) - all_dims = list(range(t1_.ndim)) - return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) - - -@DiagonalSparseTensor.implements(aten.div.Tensor) -def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: - t1_, t2_ = prepare_for_elementwise_op(t1, t2) - t2_ = DiagonalSparseTensor(1.0 / t2_.physical, t2_.v_to_ps) - all_dims = list(range(t1_.ndim)) - return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) - - -@DiagonalSparseTensor.implements(aten.mul.Scalar) -def mul_Scalar(t: DiagonalSparseTensor, scalar) -> DiagonalSparseTensor: - # TODO: maybe it could be that scalar is a scalar DST and t is a normal tensor. Need to check - # that - - assert isinstance(t, DiagonalSparseTensor) - new_physical = aten.mul.Scalar(t.physical, scalar) - return DiagonalSparseTensor(new_physical, t.v_to_ps) - - -@DiagonalSparseTensor.implements(aten.add.Tensor) -def add_Tensor( - t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 -) -> DiagonalSparseTensor: - t1_, t2_ = prepare_for_elementwise_op(t1, t2) - - if t1_.v_to_ps == t2_.v_to_ps: - new_physical = t1_.physical + t2_.physical * alpha - return DiagonalSparseTensor(new_physical, t1_.v_to_ps) - else: - raise NotImplementedError() - - -@DiagonalSparseTensor.implements(aten.transpose.int) -def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - - new_v_to_ps = [dims for dims in t.v_to_ps] - new_v_to_ps[dim0] = t.v_to_ps[dim1] - new_v_to_ps[dim1] = t.v_to_ps[dim0] - - new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) - return DiagonalSparseTensor(new_physical, new_v_to_ps) - - -def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) -> Tensor: - - # First part of the algorithm, determine how to cluster physical indices as well as the common - # p_shapes corresponding to matching v_dims. Second part translates to physical einsum. 
- - # get a map from einsum index to (tensor_idx, v_dims) - # get a map from einsum index to merge of strides corresponding to v_dims with that index - # use to_target_physical_strides on each physical and v_to_ps - # cluster pairs of (einsum_index, new_stride) using new_v_to_ps and possibly its corresponding - # p_to_vs - # get unique indices - # map output indices (there can be splits) - # call physical einsum - # build resulting dst - - # OVER - - # an index in the physical einsum is uniquely characterized by a virtual einsum index and a - # stride corresponding to the physical stride in the virtual one (note that as the virtual shape - # for two virtual index that match should match, then we want to match the strides and reshape - # accordingly). - # We want to cluster such indices whenever several appear in the same p_to_vs - - # TODO: Handle ellipsis - # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list - # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. - # For this reason, an index is decomposed into sub-indices that are then independently - # clustered. - # So if an index i in args for some DiagonalSparseTensor corresponds to a v_to_ps [j, k, l], - # We will consider three indices (i, 0), (i, 1) and (i, 2). - # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then - # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in - # the resulting einsum). - # Note that this is a problem if two virtual dimensions (from possibly different - # DiagonaSparseTensors) have the same size but not the same decomposition into physical - # dimension sizes. For now lets leave the responsibility to care about that in the calling - # functions, if we can factor code later on we will. - - index_parents = dict[tuple[int, int], tuple[int, int]]() - - def get_representative(index: tuple[int, int]) -> tuple[int, int]: - if index not in index_parents: - # If an index is not yet in a cluster, put it in its own. - index_parents[index] = index - current = index_parents[index] - if current != index: - # Compress path to representative - index_parents[index] = get_representative(current) - return index_parents[index] - - def group_indices(indices: list[tuple[int, int]]) -> None: - first_representative = get_representative(indices[0]) - for i in indices[1:]: - curr_representative = get_representative(i) - index_parents[curr_representative] = first_representative - - new_indices_pair = list[list[tuple[int, int]]]() - tensors = list[Tensor]() - indices_to_n_pdims = dict[int, int]() - for t, indices in args: - assert isinstance(t, DiagonalSparseTensor) - tensors.append(t.physical) - for ps, index in zip(t.v_to_ps, indices): - if index in indices_to_n_pdims: - assert indices_to_n_pdims[index] == len(ps) - else: - indices_to_n_pdims[index] = len(ps) - p_to_vs = p_to_vs_from_v_to_ps(t.v_to_ps) - for indices_ in p_to_vs: - # elements in indices[indices_] map to the same dimension, they should be clustered - # together - group_indices([(indices[i], sub_i) for i, sub_i in indices_]) - # record the physical dimensions, index[v] for v in vs will end-up mapping to the same - # final dimension as they were just clustered, so we can take the first, which exists as - # t is a valid DST. 
- new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs]) - - current = 0 - pair_to_int = dict[tuple[int, int], int]() - - def unique_int(pair: tuple[int, int]) -> int: - nonlocal current - if pair in pair_to_int: - return pair_to_int[pair] - pair_to_int[pair] = current - current += 1 - return pair_to_int[pair] - - new_indices = [ - [unique_int(get_representative(i)) for i in indices] for indices in new_indices_pair - ] - new_output = list[int]() - v_to_ps = list[list[int]]() - for i in output: - current_v_to_ps = [] - for j in range(indices_to_n_pdims[i]): - k = unique_int(get_representative((i, j))) - if k in new_output: - current_v_to_ps.append(new_output.index(k)) - else: - current_v_to_ps.append(len(new_output)) - new_output.append(k) - v_to_ps.append(current_v_to_ps) - - physical = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output) - # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped. - # Maybe there is a way to fix that though. - return to_most_efficient_tensor(physical, v_to_ps) - - -@DiagonalSparseTensor.implements(aten.bmm.default) -def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: - assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) - assert ( - mat1.ndim == 3 - and mat2.ndim == 3 - and mat1.shape[0] == mat2.shape[0] - and mat1.shape[2] == mat2.shape[1] - ) - - mat1_ = to_diagonal_sparse_tensor(mat1) - mat2_ = to_diagonal_sparse_tensor(mat2) - - # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes - # decompositions. If not, can reshape to common decomposition? - return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3]) - - -@DiagonalSparseTensor.implements(aten.mm.default) -def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: - assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) - assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] - - mat1_ = to_diagonal_sparse_tensor(mat1) - mat2_ = to_diagonal_sparse_tensor(mat2) - - return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) - - -# pointwise functions applied to one Tensor with `0.0 → 0` -_POINTWISE_FUNCTIONS = [ - aten.abs.default, - aten.absolute.default, - aten.asin.default, - aten.asinh.default, - aten.atan.default, - aten.atanh.default, - aten.ceil.default, - aten.erf.default, - aten.erfinv.default, - aten.expm1.default, - aten.fix.default, - aten.floor.default, - aten.hardtanh.default, - aten.leaky_relu.default, - aten.log1p.default, - aten.neg.default, - aten.negative.default, - aten.positive.default, - aten.relu.default, - aten.round.default, - aten.sgn.default, - aten.sign.default, - aten.sin.default, - aten.sinh.default, - aten.sqrt.default, - aten.square.default, - aten.tan.default, - aten.tanh.default, - aten.trunc.default, -] - -_IN_PLACE_POINTWISE_FUNCTIONS = [ - aten.abs_.default, - aten.absolute_.default, - aten.asin_.default, - aten.asinh_.default, - aten.atan_.default, - aten.atanh_.default, - aten.ceil_.default, - aten.erf_.default, - aten.erfinv_.default, - aten.expm1_.default, - aten.fix_.default, - aten.floor_.default, - aten.hardtanh_.default, - aten.leaky_relu_.default, - aten.log1p_.default, - aten.neg_.default, - aten.negative_.default, - aten.relu_.default, - aten.round_.default, - aten.sgn_.default, - aten.sign_.default, - aten.sin_.default, - aten.sinh_.default, - aten.sqrt_.default, - aten.square_.default, - aten.tan_.default, - aten.tanh_.default, - 
aten.trunc_.default, -] - - -def _override_pointwise(op): - @DiagonalSparseTensor.implements(op) - def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - return DiagonalSparseTensor(op(t.physical), t.v_to_ps) - - return func_ - - -def _override_inplace_pointwise(op): - @DiagonalSparseTensor.implements(op) - def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - op(t.physical) - return t - - -for pointwise_func in _POINTWISE_FUNCTIONS: - _override_pointwise(pointwise_func) - -for pointwise_func in _IN_PLACE_POINTWISE_FUNCTIONS: - _override_inplace_pointwise(pointwise_func) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_diagonal_sparse_tensor.py index e59f7d2c5..05e1c83ba 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_diagonal_sparse_tensor.py @@ -4,15 +4,17 @@ from torch.testing import assert_close from utils.tensors import randn_, tensor_, zeros_ -from torchjd.sparse._diagonal_sparse_tensor import ( +from torchjd.sparse._aten_function_overrides.einsum import einsum +from torchjd.sparse._aten_function_overrides.pointwise import ( _IN_PLACE_POINTWISE_FUNCTIONS, _POINTWISE_FUNCTIONS, +) +from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim +from torchjd.sparse._diagonal_sparse_tensor import ( DiagonalSparseTensor, - einsum, encode_by_order, fix_ungrouped_dims, get_groupings, - unsquash_pdim, ) From 59bcf06a9dc8e7c056929789f65343ad5ad1fe03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 18:07:31 +0100 Subject: [PATCH 161/182] Rename DiagonalSparseTensor to StructuredSparseTensor --- src/torchjd/autogram/_engine.py | 4 +- src/torchjd/sparse/__init__.py | 2 +- .../_aten_function_overrides/backward.py | 28 +++---- .../sparse/_aten_function_overrides/einsum.py | 84 +++++++++---------- .../_aten_function_overrides/pointwise.py | 38 ++++----- .../sparse/_aten_function_overrides/shape.py | 70 ++++++++-------- ...tensor.py => _structured_sparse_tensor.py} | 24 +++--- ...or.py => test_structured_sparse_tensor.py} | 44 +++++----- 8 files changed, 147 insertions(+), 147 deletions(-) rename src/torchjd/sparse/{_diagonal_sparse_tensor.py => _structured_sparse_tensor.py} (95%) rename tests/unit/sparse/{test_diagonal_sparse_tensor.py => test_structured_sparse_tensor.py} (86%) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 89815713c..dafe362ec 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,7 +4,7 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge -from torchjd.sparse import make_dst +from torchjd.sparse import make_sst from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator @@ -176,7 +176,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: output_dims = list(range(output.ndim)) v_to_ps = [[dim] for dim in output_dims * 2] - jac_output = make_dst(torch.ones_like(output), v_to_ps) + jac_output = make_sst(torch.ones_like(output), v_to_ps) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py index c541cc462..7a161b6ad 100644 --- a/src/torchjd/sparse/__init__.py +++ b/src/torchjd/sparse/__init__.py @@ -1,3 +1,3 @@ # Need to import this to execute the code inside and thus to override the functions from . 
import _aten_function_overrides -from ._diagonal_sparse_tensor import DiagonalSparseTensor, make_dst +from ._structured_sparse_tensor import StructuredSparseTensor, make_sst diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py index ed5d283ef..dd8e1c1a4 100644 --- a/src/torchjd/sparse/_aten_function_overrides/backward.py +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -1,36 +1,36 @@ from torch import Tensor from torch.ops import aten # type: ignore -from torchjd.sparse import DiagonalSparseTensor +from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor -@DiagonalSparseTensor.implements(aten.threshold_backward.default) +@StructuredSparseTensor.implements(aten.threshold_backward.default) def threshold_backward_default( - grad_output: DiagonalSparseTensor, self: Tensor, threshold -) -> DiagonalSparseTensor: + grad_output: StructuredSparseTensor, self: Tensor, threshold +) -> StructuredSparseTensor: new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) - return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) + return StructuredSparseTensor(new_physical, grad_output.v_to_ps) -@DiagonalSparseTensor.implements(aten.hardtanh_backward.default) +@StructuredSparseTensor.implements(aten.hardtanh_backward.default) def hardtanh_backward_default( - grad_output: DiagonalSparseTensor, + grad_output: StructuredSparseTensor, self: Tensor, min_val: Tensor | int | float, max_val: Tensor | int | float, -) -> DiagonalSparseTensor: - if isinstance(self, DiagonalSparseTensor): +) -> StructuredSparseTensor: + if isinstance(self, StructuredSparseTensor): raise NotImplementedError() new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) - return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) + return StructuredSparseTensor(new_physical, grad_output.v_to_ps) -@DiagonalSparseTensor.implements(aten.hardswish_backward.default) -def hardswish_backward_default(grad_output: DiagonalSparseTensor, self: Tensor): - if isinstance(self, DiagonalSparseTensor): +@StructuredSparseTensor.implements(aten.hardswish_backward.default) +def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor): + if isinstance(self, StructuredSparseTensor): raise NotImplementedError() new_physical = aten.hardswish_backward.default(grad_output.physical, self) - return DiagonalSparseTensor(new_physical, grad_output.v_to_ps) + return StructuredSparseTensor(new_physical, grad_output.v_to_ps) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index 0a6a2a318..86b8bac86 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -2,23 +2,23 @@ from torch import Tensor, tensor from torch.ops import aten # type: ignore -from torchjd.sparse import DiagonalSparseTensor -from torchjd.sparse._diagonal_sparse_tensor import ( +from torchjd.sparse._structured_sparse_tensor import ( + StructuredSparseTensor, p_to_vs_from_v_to_ps, - to_diagonal_sparse_tensor, to_most_efficient_tensor, + to_structured_sparse_tensor, ) def prepare_for_elementwise_op( t1: Tensor | int | float, t2: Tensor | int | float -) -> tuple[DiagonalSparseTensor, DiagonalSparseTensor]: +) -> tuple[StructuredSparseTensor, StructuredSparseTensor]: """ - Prepares two DSTs of the same shape from two args, one of those being a DST, and the other being 
- a DST, Tensor, int or float. + Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being + a SST, Tensor, int or float. """ - assert isinstance(t1, DiagonalSparseTensor) or isinstance(t2, DiagonalSparseTensor) + assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor) if isinstance(t1, int) or isinstance(t1, float): t1_ = tensor(t1, device=t2.device) @@ -31,13 +31,13 @@ def prepare_for_elementwise_op( t2_ = t2 t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) - t1_ = to_diagonal_sparse_tensor(t1_) - t2_ = to_diagonal_sparse_tensor(t2_) + t1_ = to_structured_sparse_tensor(t1_) + t2_ = to_structured_sparse_tensor(t2_) return t1_, t2_ -@DiagonalSparseTensor.implements(aten.mul.Tensor) +@StructuredSparseTensor.implements(aten.mul.Tensor) def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: # Element-wise multiplication with broadcasting t1_, t2_ = prepare_for_elementwise_op(t1, t2) @@ -45,38 +45,38 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) -@DiagonalSparseTensor.implements(aten.div.Tensor) +@StructuredSparseTensor.implements(aten.div.Tensor) def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - t2_ = DiagonalSparseTensor(1.0 / t2_.physical, t2_.v_to_ps) + t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.v_to_ps) all_dims = list(range(t1_.ndim)) return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) -@DiagonalSparseTensor.implements(aten.mul.Scalar) -def mul_Scalar(t: DiagonalSparseTensor, scalar) -> DiagonalSparseTensor: - # TODO: maybe it could be that scalar is a scalar DST and t is a normal tensor. Need to check +@StructuredSparseTensor.implements(aten.mul.Scalar) +def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor: + # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check # that - assert isinstance(t, DiagonalSparseTensor) + assert isinstance(t, StructuredSparseTensor) new_physical = aten.mul.Scalar(t.physical, scalar) - return DiagonalSparseTensor(new_physical, t.v_to_ps) + return StructuredSparseTensor(new_physical, t.v_to_ps) -@DiagonalSparseTensor.implements(aten.add.Tensor) +@StructuredSparseTensor.implements(aten.add.Tensor) def add_Tensor( t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 -) -> DiagonalSparseTensor: +) -> StructuredSparseTensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) if t1_.v_to_ps == t2_.v_to_ps: new_physical = t1_.physical + t2_.physical * alpha - return DiagonalSparseTensor(new_physical, t1_.v_to_ps) + return StructuredSparseTensor(new_physical, t1_.v_to_ps) else: raise NotImplementedError() -def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) -> Tensor: +def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -> Tensor: # First part of the algorithm, determine how to cluster physical indices as well as the common # p_shapes corresponding to matching v_dims. Second part translates to physical einsum. 
@@ -89,7 +89,7 @@ def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) -> # get unique indices # map output indices (there can be splits) # call physical einsum - # build resulting dst + # build resulting sst # OVER @@ -104,7 +104,7 @@ def einsum(*args: tuple[DiagonalSparseTensor, list[int]], output: list[int]) -> # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. # For this reason, an index is decomposed into sub-indices that are then independently # clustered. - # So if an index i in args for some DiagonalSparseTensor corresponds to a v_to_ps [j, k, l], + # So if an index i in args for some StructuredSparseTensor corresponds to a v_to_ps [j, k, l], # We will consider three indices (i, 0), (i, 1) and (i, 2). # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in @@ -136,7 +136,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None: tensors = list[Tensor]() indices_to_n_pdims = dict[int, int]() for t, indices in args: - assert isinstance(t, DiagonalSparseTensor) + assert isinstance(t, StructuredSparseTensor) tensors.append(t.physical) for ps, index in zip(t.v_to_ps, indices): if index in indices_to_n_pdims: @@ -150,7 +150,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None: group_indices([(indices[i], sub_i) for i, sub_i in indices_]) # record the physical dimensions, index[v] for v in vs will end-up mapping to the same # final dimension as they were just clustered, so we can take the first, which exists as - # t is a valid DST. + # t is a valid SST. new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs]) current = 0 @@ -186,9 +186,9 @@ def unique_int(pair: tuple[int, int]) -> int: return to_most_efficient_tensor(physical, v_to_ps) -@DiagonalSparseTensor.implements(aten.bmm.default) +@StructuredSparseTensor.implements(aten.bmm.default) def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: - assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) + assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) assert ( mat1.ndim == 3 and mat2.ndim == 3 @@ -196,42 +196,42 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: and mat1.shape[2] == mat2.shape[1] ) - mat1_ = to_diagonal_sparse_tensor(mat1) - mat2_ = to_diagonal_sparse_tensor(mat2) + mat1_ = to_structured_sparse_tensor(mat1) + mat2_ = to_structured_sparse_tensor(mat2) # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes # decompositions. If not, can reshape to common decomposition? 
return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3]) -@DiagonalSparseTensor.implements(aten.mm.default) +@StructuredSparseTensor.implements(aten.mm.default) def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: - assert isinstance(mat1, DiagonalSparseTensor) or isinstance(mat2, DiagonalSparseTensor) + assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] - mat1_ = to_diagonal_sparse_tensor(mat1) - mat2_ = to_diagonal_sparse_tensor(mat2) + mat1_ = to_structured_sparse_tensor(mat1) + mat2_ = to_structured_sparse_tensor(mat2) return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) -@DiagonalSparseTensor.implements(aten.mean.default) -def mean_default(t: DiagonalSparseTensor) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) +@StructuredSparseTensor.implements(aten.mean.default) +def mean_default(t: StructuredSparseTensor) -> Tensor: + assert isinstance(t, StructuredSparseTensor) return aten.sum.default(t.physical) / t.numel() -@DiagonalSparseTensor.implements(aten.sum.default) -def sum_default(t: DiagonalSparseTensor) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) +@StructuredSparseTensor.implements(aten.sum.default) +def sum_default(t: StructuredSparseTensor) -> Tensor: + assert isinstance(t, StructuredSparseTensor) return aten.sum.default(t.physical) -@DiagonalSparseTensor.implements(aten.sum.dim_IntList) +@StructuredSparseTensor.implements(aten.sum.dim_IntList) def sum_dim_IntList( - t: DiagonalSparseTensor, dim: list[int], keepdim: bool = False, dtype=None + t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None ) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) + assert isinstance(t, StructuredSparseTensor) if dtype: raise NotImplementedError() diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py index 8e0e89f96..c74c79ac8 100644 --- a/src/torchjd/sparse/_aten_function_overrides/pointwise.py +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -1,6 +1,6 @@ from torch.ops import aten # type: ignore -from torchjd.sparse import DiagonalSparseTensor +from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor # pointwise functions applied to one Tensor with `0.0 → 0` _POINTWISE_FUNCTIONS = [ @@ -68,18 +68,18 @@ def _override_pointwise(op): - @DiagonalSparseTensor.implements(op) - def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) - return DiagonalSparseTensor(op(t.physical), t.v_to_ps) + @StructuredSparseTensor.implements(op) + def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + return StructuredSparseTensor(op(t.physical), t.v_to_ps) return func_ def _override_inplace_pointwise(op): - @DiagonalSparseTensor.implements(op) - def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) + @StructuredSparseTensor.implements(op) + def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) op(t.physical) return t @@ -91,22 +91,22 @@ def func_(t: DiagonalSparseTensor) -> DiagonalSparseTensor: _override_inplace_pointwise(pointwise_func) -@DiagonalSparseTensor.implements(aten.pow.Tensor_Scalar) -def pow_Tensor_Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: - assert isinstance(t, 
DiagonalSparseTensor) +@StructuredSparseTensor.implements(aten.pow.Tensor_Scalar) +def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) if exponent <= 0.0: # Need to densify because we don't have pow(0.0, exponent) = 0.0 return aten.pow.Tensor_Scalar(t.to_dense(), exponent) new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) - return DiagonalSparseTensor(new_physical, t.v_to_ps) + return StructuredSparseTensor(new_physical, t.v_to_ps) # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. -@DiagonalSparseTensor.implements(aten.pow_.Scalar) -def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) +@StructuredSparseTensor.implements(aten.pow_.Scalar) +def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) if exponent <= 0.0: # Need to densify because we don't have pow(0.0, exponent) = 0.0 @@ -117,9 +117,9 @@ def pow__Scalar(t: DiagonalSparseTensor, exponent: float) -> DiagonalSparseTenso return t -@DiagonalSparseTensor.implements(aten.div.Scalar) -def div_Scalar(t: DiagonalSparseTensor, divisor: float) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) +@StructuredSparseTensor.implements(aten.div.Scalar) +def div_Scalar(t: StructuredSparseTensor, divisor: float) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) new_physical = aten.div.Scalar(t.physical, divisor) - return DiagonalSparseTensor(new_physical, t.v_to_ps) + return StructuredSparseTensor(new_physical, t.v_to_ps) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 7e9939770..57946f4f7 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -5,8 +5,8 @@ from torch import Tensor, tensor from torch.ops import aten # type: ignore -from torchjd.sparse import DiagonalSparseTensor -from torchjd.sparse._diagonal_sparse_tensor import ( +from torchjd.sparse._structured_sparse_tensor import ( + StructuredSparseTensor, encode_v_to_ps, fix_dim_encoding, print_fallback, @@ -15,9 +15,9 @@ ) -@DiagonalSparseTensor.implements(aten.view.default) -def view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) +@StructuredSparseTensor.implements(aten.view.default) +def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: + assert isinstance(t, StructuredSparseTensor) shape = infer_shape(shape, t.numel()) @@ -120,16 +120,16 @@ def new_encoding_fn(d: int) -> list[int]: return new_physical, new_encoding -@DiagonalSparseTensor.implements(aten._unsafe_view.default) -def _unsafe_view_default(t: DiagonalSparseTensor, shape: list[int]) -> Tensor: +@StructuredSparseTensor.implements(aten._unsafe_view.default) +def _unsafe_view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: return view_default( t, shape ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp -@DiagonalSparseTensor.implements(aten.unsqueeze.default) -def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) +@StructuredSparseTensor.implements(aten.unsqueeze.default) +def unsqueeze_default(t: StructuredSparseTensor, 
dim: int) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) assert -t.ndim - 1 <= dim < t.ndim + 1 if dim < 0: @@ -138,12 +138,12 @@ def unsqueeze_default(t: DiagonalSparseTensor, dim: int) -> DiagonalSparseTensor new_v_to_ps = [p for p in t.v_to_ps] # Deepcopy the list to not modify the original v_to_ps new_v_to_ps.insert(dim, []) - return DiagonalSparseTensor(t.physical, new_v_to_ps) + return StructuredSparseTensor(t.physical, new_v_to_ps) -@DiagonalSparseTensor.implements(aten.squeeze.dims) -def squeeze_dims(t: DiagonalSparseTensor, dims: list[int] | int | None) -> Tensor: - assert isinstance(t, DiagonalSparseTensor) +@StructuredSparseTensor.implements(aten.squeeze.dims) +def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Tensor: + assert isinstance(t, StructuredSparseTensor) if dims is None: excluded = set(range(t.ndim)) @@ -157,21 +157,21 @@ def squeeze_dims(t: DiagonalSparseTensor, dims: list[int] | int | None) -> Tenso return to_most_efficient_tensor(t.physical, new_v_to_ps) -@DiagonalSparseTensor.implements(aten.permute.default) -def permute_default(t: DiagonalSparseTensor, dims: list[int]) -> DiagonalSparseTensor: +@StructuredSparseTensor.implements(aten.permute.default) +def permute_default(t: StructuredSparseTensor, dims: list[int]) -> StructuredSparseTensor: new_v_to_ps = [t.v_to_ps[d] for d in dims] new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) - return DiagonalSparseTensor(new_physical, new_v_to_ps) + return StructuredSparseTensor(new_physical, new_v_to_ps) -@DiagonalSparseTensor.implements(aten.cat.default) +@StructuredSparseTensor.implements(aten.cat.default) def cat_default(tensors: list[Tensor], dim: int) -> Tensor: - if any(not isinstance(t, DiagonalSparseTensor) for t in tensors): + if any(not isinstance(t, StructuredSparseTensor) for t in tensors): print_fallback(aten.cat.default, (tensors, dim), {}) return aten.cat.default([unwrap_to_dense(t) for t in tensors]) - tensors_ = [cast(DiagonalSparseTensor, t) for t in tensors] + tensors_ = [cast(StructuredSparseTensor, t) for t in tensors] ref_tensor = tensors_[0] ref_strides = ref_tensor.strides if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]): @@ -210,13 +210,13 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: new_v_to_ps = ref_tensor.v_to_ps new_physical = aten.cat.default(physicals, dim=pdim) - return DiagonalSparseTensor(new_physical, new_v_to_ps) + return StructuredSparseTensor(new_physical, new_v_to_ps) -@DiagonalSparseTensor.implements(aten.expand.default) -def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseTensor: +@StructuredSparseTensor.implements(aten.expand.default) +def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSparseTensor: # note that sizes could also be just an int, or a torch.Size i think - assert isinstance(t, DiagonalSparseTensor) + assert isinstance(t, StructuredSparseTensor) assert isinstance(sizes, list) assert len(sizes) >= t.ndim @@ -245,10 +245,10 @@ def expand_default(t: DiagonalSparseTensor, sizes: list[int]) -> DiagonalSparseT new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps) new_physical = new_physical.movedim(list(range(len(destination))), destination) - return DiagonalSparseTensor(new_physical, new_v_to_ps) + return StructuredSparseTensor(new_physical, new_v_to_ps) -@DiagonalSparseTensor.implements(aten.broadcast_tensors.default) +@StructuredSparseTensor.implements(aten.broadcast_tensors.default) def 
broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: if len(tensors) != 2: raise NotImplementedError() @@ -275,11 +275,11 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape) -@DiagonalSparseTensor.implements(aten.slice.Tensor) +@StructuredSparseTensor.implements(aten.slice.Tensor) def slice_Tensor( - t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 -) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) + t: StructuredSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 +) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) physical_dims = t.v_to_ps[dim] @@ -308,16 +308,16 @@ def slice_Tensor( physical_dim = physical_dims[0] new_physical = aten.slice.Tensor(t.physical, physical_dim, start, end, step) - return DiagonalSparseTensor(new_physical, t.v_to_ps) + return StructuredSparseTensor(new_physical, t.v_to_ps) -@DiagonalSparseTensor.implements(aten.transpose.int) -def transpose_int(t: DiagonalSparseTensor, dim0: int, dim1: int) -> DiagonalSparseTensor: - assert isinstance(t, DiagonalSparseTensor) +@StructuredSparseTensor.implements(aten.transpose.int) +def transpose_int(t: StructuredSparseTensor, dim0: int, dim1: int) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) new_v_to_ps = [dims for dims in t.v_to_ps] new_v_to_ps[dim0] = t.v_to_ps[dim1] new_v_to_ps[dim1] = t.v_to_ps[dim0] new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) - return DiagonalSparseTensor(new_physical, new_v_to_ps) + return StructuredSparseTensor(new_physical, new_v_to_ps) diff --git a/src/torchjd/sparse/_diagonal_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py similarity index 95% rename from src/torchjd/sparse/_diagonal_sparse_tensor.py rename to src/torchjd/sparse/_structured_sparse_tensor.py index a051e5c13..33cbb31fa 100644 --- a/src/torchjd/sparse/_diagonal_sparse_tensor.py +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -9,7 +9,7 @@ from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten -class DiagonalSparseTensor(Tensor): +class StructuredSparseTensor(Tensor): _HANDLED_FUNCTIONS = dict() @staticmethod @@ -101,7 +101,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(*unwrapped_args, **unwrapped_kwargs) def __repr__(self, *, tensor_contents=None) -> str: - return f"DiagonalSparseTensor(physical={self.physical}, v_to_ps={self.v_to_ps})" + return f"StructuredSparseTensor(physical={self.physical}, v_to_ps={self.v_to_ps})" def debug_info(self) -> str: info = ( @@ -129,7 +129,7 @@ def decorator(func): def print_fallback(func, args, kwargs) -> None: def tensor_to_str(t: Tensor) -> str: result = f"{t.__class__.__name__} - shape: {t.shape}" - if isinstance(t, DiagonalSparseTensor): + if isinstance(t, StructuredSparseTensor): result += f" - pshape: {t.physical.shape} - v_to_ps: {t.v_to_ps}" return result @@ -166,7 +166,7 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]: Example: Imagine a vector of size 3, and of value [1, 2, 3]. - Imagine a DST t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps. + Imagine a SST t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps. t.to_dense() is [1, 0, 0, 0, 2, 0, 0, 0, 3] (it's the flattening of the diagonal matrix [[1, 0, 0], [0, 2, 0], [0, 0, 3]]). 
When you move by 1 on physical dimension 0, you move by 4 on virtual dimension 0, i.e. @@ -310,11 +310,11 @@ def encode_v_to_ps(v_to_ps: list[list[int]]) -> tuple[list[list[int]], list[int] return tree_unflatten(sorted_flat_v_to_ps, spec), destination -def to_diagonal_sparse_tensor(t: Tensor) -> DiagonalSparseTensor: - if isinstance(t, DiagonalSparseTensor): +def to_structured_sparse_tensor(t: Tensor) -> StructuredSparseTensor: + if isinstance(t, StructuredSparseTensor): return t else: - return make_dst(t, [[i] for i in range(t.ndim)]) + return make_sst(t, [[i] for i in range(t.ndim)]) def to_most_efficient_tensor(physical: Tensor, v_to_ps: list[list[int]]) -> Tensor: @@ -340,11 +340,11 @@ def to_most_efficient_tensor(physical: Tensor, v_to_ps: list[list[int]]) -> Tens physical, list(range(physical.ndim)), [pdims[0] for pdims in new_v_to_ps] ) else: - return DiagonalSparseTensor(physical, v_to_ps) + return StructuredSparseTensor(physical, v_to_ps) def unwrap_to_dense(t: Tensor): - if isinstance(t, DiagonalSparseTensor): + if isinstance(t, StructuredSparseTensor): return t.to_dense() else: return t @@ -412,10 +412,10 @@ def fix_ungrouped_dims( return nphysical, new_v_to_ps -def make_dst(physical: Tensor, v_to_ps: list[list[int]]) -> DiagonalSparseTensor: - """Fix physical and v_to_ps and create a DiagonalSparseTensor with them.""" +def make_sst(physical: Tensor, v_to_ps: list[list[int]]) -> StructuredSparseTensor: + """Fix physical and v_to_ps and create a StructuredSparseTensor with them.""" physical, v_to_ps = fix_dim_encoding(physical, v_to_ps) physical, v_to_ps = fix_dim_of_size_1(physical, v_to_ps) physical, v_to_ps = fix_ungrouped_dims(physical, v_to_ps) - return DiagonalSparseTensor(physical, v_to_ps) + return StructuredSparseTensor(physical, v_to_ps) diff --git a/tests/unit/sparse/test_diagonal_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py similarity index 86% rename from tests/unit/sparse/test_diagonal_sparse_tensor.py rename to tests/unit/sparse/test_structured_sparse_tensor.py index 05e1c83ba..f4097e538 100644 --- a/tests/unit/sparse/test_diagonal_sparse_tensor.py +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -10,8 +10,8 @@ _POINTWISE_FUNCTIONS, ) from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim -from torchjd.sparse._diagonal_sparse_tensor import ( - DiagonalSparseTensor, +from torchjd.sparse._structured_sparse_tensor import ( + StructuredSparseTensor, encode_by_order, fix_ungrouped_dims, get_groupings, @@ -22,7 +22,7 @@ def test_to_dense(): n = 2 m = 3 a = randn_([n, m]) - b = DiagonalSparseTensor(a, [[0], [1], [1], [0]]) + b = StructuredSparseTensor(a, [[0], [1], [1], [0]]) c = b.to_dense() for i in range(n): @@ -32,7 +32,7 @@ def test_to_dense(): def test_to_dense2(): a = tensor_([1.0, 2.0, 3.0]) - b = DiagonalSparseTensor(a, [[0, 0]]) + b = StructuredSparseTensor(a, [[0, 0]]) c = b.to_dense() expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]) assert torch.all(torch.eq(c, expected)) @@ -55,14 +55,14 @@ def test_einsum( b_indices: list[int], output_indices: list[int], ): - a = DiagonalSparseTensor(randn_(a_pshape), a_v_to_ps) - b = DiagonalSparseTensor(randn_(b_pshape), b_v_to_ps) + a = StructuredSparseTensor(randn_(a_pshape), a_v_to_ps) + b = StructuredSparseTensor(randn_(b_pshape), b_v_to_ps) res = einsum((a, a_indices), (b, b_indices), output=output_indices) expected = torch.einsum(a.to_dense(), a_indices, b.to_dense(), b_indices, output_indices) - assert isinstance(res, DiagonalSparseTensor) + 
assert isinstance(res, StructuredSparseTensor) assert_close(res.to_dense(), expected) @@ -75,9 +75,9 @@ def test_einsum( [2, 3, 4], ], ) -def test_diagonal_sparse_tensor_scalar(shape: list[int]): +def test_structured_sparse_tensor_scalar(shape: list[int]): a = randn_(shape) - b = DiagonalSparseTensor(a, [[dim] for dim in range(len(shape))]) + b = StructuredSparseTensor(a, [[dim] for dim in range(len(shape))]) assert_close(a, b.to_dense()) @@ -85,7 +85,7 @@ def test_diagonal_sparse_tensor_scalar(shape: list[int]): @mark.parametrize("dim", [2, 3, 4, 5, 10]) def test_diag_equivalence(dim: int): a = randn_([dim]) - b = DiagonalSparseTensor(a, [[0], [0]]) + b = StructuredSparseTensor(a, [[0], [0]]) diag_a = torch.diag(a) @@ -95,7 +95,7 @@ def test_diag_equivalence(dim: int): def test_three_virtual_single_physical(): dim = 10 a = randn_([dim]) - b = DiagonalSparseTensor(a, [[0], [0], [0]]) + b = StructuredSparseTensor(a, [[0], [0], [0]]) expected = zeros_([dim, dim, dim]) for i in range(dim): @@ -108,10 +108,10 @@ def test_three_virtual_single_physical(): def test_pointwise(func): dim = 10 a = randn_([dim]) - b = DiagonalSparseTensor(a, [[0], [0]]) + b = StructuredSparseTensor(a, [[0], [0]]) c = b.to_dense() res = func(b) - assert isinstance(res, DiagonalSparseTensor) + assert isinstance(res, StructuredSparseTensor) assert_close(res.to_dense(), func(c), equal_nan=True) @@ -120,10 +120,10 @@ def test_pointwise(func): def test_inplace_pointwise(func): dim = 10 a = randn_([dim]) - b = DiagonalSparseTensor(a, [[0], [0]]) + b = StructuredSparseTensor(a, [[0], [0]]) c = b.to_dense() func(b) - assert isinstance(b, DiagonalSparseTensor) + assert isinstance(b, StructuredSparseTensor) assert_close(b.to_dense(), func(c), equal_nan=True) @@ -132,7 +132,7 @@ def test_inplace_pointwise(func): def test_unary(func): dim = 10 a = randn_([dim]) - b = DiagonalSparseTensor(a, [[0], [0]]) + b = StructuredSparseTensor(a, [[0], [0]]) c = b.to_dense() res = func(b) @@ -163,12 +163,12 @@ def test_view( expected_v_to_ps: list[list[int]], ): a = randn_(tuple(physical_shape)) - t = DiagonalSparseTensor(a, v_to_ps) + t = StructuredSparseTensor(a, v_to_ps) result = aten.view.default(t, target_shape) expected = t.to_dense().reshape(target_shape) - assert isinstance(result, DiagonalSparseTensor) + assert isinstance(result, StructuredSparseTensor) assert list(result.physical.shape) == expected_physical_shape assert result.v_to_ps == expected_v_to_ps assert torch.all(torch.eq(result.to_dense(), expected)) @@ -260,19 +260,19 @@ def test_unsquash_pdim( @mark.parametrize( - ["dst_args", "dim"], + ["sst_args", "dim"], [ ([([3], [[0], [0]]), ([3], [[0], [0]])], 1), ([([3, 2], [[0], [1, 0]]), ([3, 2], [[0], [1, 0]])], 1), ], ) def test_concatenate( - dst_args: list[tuple[list[int], list[list[int]]]], + sst_args: list[tuple[list[int], list[list[int]]]], dim: int, ): - tensors = [DiagonalSparseTensor(randn_(pshape), v_to_ps) for pshape, v_to_ps in dst_args] + tensors = [StructuredSparseTensor(randn_(pshape), v_to_ps) for pshape, v_to_ps in sst_args] res = aten.cat.default(tensors, dim) expected = aten.cat.default([t.to_dense() for t in tensors], dim) - assert isinstance(res, DiagonalSparseTensor) + assert isinstance(res, StructuredSparseTensor) assert torch.all(torch.eq(res.to_dense(), expected)) From f693e99e6bf0d60a04ca9675c385df6d879ba4f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 18:22:01 +0100 Subject: [PATCH 162/182] Improve error message for cat_default --- 
src/torchjd/sparse/_aten_function_overrides/shape.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py
index 57946f4f7..3d39f84c2 100644
--- a/src/torchjd/sparse/_aten_function_overrides/shape.py
+++ b/src/torchjd/sparse/_aten_function_overrides/shape.py
@@ -175,7 +175,11 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
     ref_tensor = tensors_[0]
     ref_strides = ref_tensor.strides
     if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]):
-        raise NotImplementedError()
+        raise NotImplementedError(
+            "Override for aten.cat.default does not support SSTs that do not all have the same "
+            f"strides. Found the following strides:\n{[t.strides for t in tensors_]} and the "
+            f"following dim: {dim}."
+        )
 
     # We need to try to find the (pretty sure it either does not exist or is unique) physical
     # dimension that makes us only move on virtual dimension dim. It also needs to be such that

From 7991ac16c09ece225a7e7b473263c349da0fc0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?=
Date: Sun, 2 Nov 2025 18:25:57 +0100
Subject: [PATCH 163/182] Add alias impl for StructuredSparseTensor.implements

---
 .../_aten_function_overrides/backward.py      |  8 +++----
 .../sparse/_aten_function_overrides/einsum.py | 19 +++++++++--------
 .../_aten_function_overrides/pointwise.py     | 12 +++++------
 .../sparse/_aten_function_overrides/shape.py  | 21 ++++++++++---------
 .../sparse/_structured_sparse_tensor.py       |  3 +++
 5 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py
index dd8e1c1a4..77a9fa57f 100644
--- a/src/torchjd/sparse/_aten_function_overrides/backward.py
+++ b/src/torchjd/sparse/_aten_function_overrides/backward.py
@@ -1,10 +1,10 @@
 from torch import Tensor
 from torch.ops import aten  # type: ignore
 
-from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor
+from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl
 
 
-@StructuredSparseTensor.implements(aten.threshold_backward.default)
+@impl(aten.threshold_backward.default)
 def threshold_backward_default(
     grad_output: StructuredSparseTensor, self: Tensor, threshold
 ) -> StructuredSparseTensor:
@@ -13,7 +13,7 @@ def threshold_backward_default(
     return StructuredSparseTensor(new_physical, grad_output.v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.hardtanh_backward.default)
+@impl(aten.hardtanh_backward.default)
 def hardtanh_backward_default(
     grad_output: StructuredSparseTensor,
     self: Tensor,
@@ -27,7 +27,7 @@ def hardtanh_backward_default(
     return StructuredSparseTensor(new_physical, grad_output.v_to_ps)
 
 
-@StructuredSparseTensor.implements(aten.hardswish_backward.default)
+@impl(aten.hardswish_backward.default)
 def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor):
     if isinstance(self, StructuredSparseTensor):
         raise NotImplementedError()
diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py
index 86b8bac86..4732292be 100644
--- a/src/torchjd/sparse/_aten_function_overrides/einsum.py
+++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py
@@ -4,6 +4,7 @@
 from torchjd.sparse._structured_sparse_tensor import (
     StructuredSparseTensor,
+    impl,
     p_to_vs_from_v_to_ps,
     to_most_efficient_tensor,
     to_structured_sparse_tensor,
 )
@@ -37,7 +38,7 @@ def 
prepare_for_elementwise_op( return t1_, t2_ -@StructuredSparseTensor.implements(aten.mul.Tensor) +@impl(aten.mul.Tensor) def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: # Element-wise multiplication with broadcasting t1_, t2_ = prepare_for_elementwise_op(t1, t2) @@ -45,7 +46,7 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) -@StructuredSparseTensor.implements(aten.div.Tensor) +@impl(aten.div.Tensor) def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.v_to_ps) @@ -53,7 +54,7 @@ def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) -@StructuredSparseTensor.implements(aten.mul.Scalar) +@impl(aten.mul.Scalar) def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor: # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check # that @@ -63,7 +64,7 @@ def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor: return StructuredSparseTensor(new_physical, t.v_to_ps) -@StructuredSparseTensor.implements(aten.add.Tensor) +@impl(aten.add.Tensor) def add_Tensor( t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 ) -> StructuredSparseTensor: @@ -186,7 +187,7 @@ def unique_int(pair: tuple[int, int]) -> int: return to_most_efficient_tensor(physical, v_to_ps) -@StructuredSparseTensor.implements(aten.bmm.default) +@impl(aten.bmm.default) def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) assert ( @@ -204,7 +205,7 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3]) -@StructuredSparseTensor.implements(aten.mm.default) +@impl(aten.mm.default) def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] @@ -215,19 +216,19 @@ def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) -@StructuredSparseTensor.implements(aten.mean.default) +@impl(aten.mean.default) def mean_default(t: StructuredSparseTensor) -> Tensor: assert isinstance(t, StructuredSparseTensor) return aten.sum.default(t.physical) / t.numel() -@StructuredSparseTensor.implements(aten.sum.default) +@impl(aten.sum.default) def sum_default(t: StructuredSparseTensor) -> Tensor: assert isinstance(t, StructuredSparseTensor) return aten.sum.default(t.physical) -@StructuredSparseTensor.implements(aten.sum.dim_IntList) +@impl(aten.sum.dim_IntList) def sum_dim_IntList( t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None ) -> Tensor: diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py index c74c79ac8..85798c65f 100644 --- a/src/torchjd/sparse/_aten_function_overrides/pointwise.py +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -1,6 +1,6 @@ from torch.ops import aten # type: ignore -from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor +from torchjd.sparse._structured_sparse_tensor import 
StructuredSparseTensor, impl # pointwise functions applied to one Tensor with `0.0 → 0` _POINTWISE_FUNCTIONS = [ @@ -68,7 +68,7 @@ def _override_pointwise(op): - @StructuredSparseTensor.implements(op) + @impl(op) def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) return StructuredSparseTensor(op(t.physical), t.v_to_ps) @@ -77,7 +77,7 @@ def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: def _override_inplace_pointwise(op): - @StructuredSparseTensor.implements(op) + @impl(op) def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) op(t.physical) @@ -91,7 +91,7 @@ def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: _override_inplace_pointwise(pointwise_func) -@StructuredSparseTensor.implements(aten.pow.Tensor_Scalar) +@impl(aten.pow.Tensor_Scalar) def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) @@ -104,7 +104,7 @@ def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredS # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. -@StructuredSparseTensor.implements(aten.pow_.Scalar) +@impl(aten.pow_.Scalar) def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) @@ -117,7 +117,7 @@ def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseT return t -@StructuredSparseTensor.implements(aten.div.Scalar) +@impl(aten.div.Scalar) def div_Scalar(t: StructuredSparseTensor, divisor: float) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 3d39f84c2..c697f94b9 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -9,13 +9,14 @@ StructuredSparseTensor, encode_v_to_ps, fix_dim_encoding, + impl, print_fallback, to_most_efficient_tensor, unwrap_to_dense, ) -@StructuredSparseTensor.implements(aten.view.default) +@impl(aten.view.default) def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: assert isinstance(t, StructuredSparseTensor) @@ -120,14 +121,14 @@ def new_encoding_fn(d: int) -> list[int]: return new_physical, new_encoding -@StructuredSparseTensor.implements(aten._unsafe_view.default) +@impl(aten._unsafe_view.default) def _unsafe_view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: return view_default( t, shape ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp -@StructuredSparseTensor.implements(aten.unsqueeze.default) +@impl(aten.unsqueeze.default) def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) assert -t.ndim - 1 <= dim < t.ndim + 1 @@ -141,7 +142,7 @@ def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTe return StructuredSparseTensor(t.physical, new_v_to_ps) -@StructuredSparseTensor.implements(aten.squeeze.dims) +@impl(aten.squeeze.dims) def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Tensor: assert isinstance(t, StructuredSparseTensor) @@ -157,7 +158,7 @@ def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Ten return 
to_most_efficient_tensor(t.physical, new_v_to_ps) -@StructuredSparseTensor.implements(aten.permute.default) +@impl(aten.permute.default) def permute_default(t: StructuredSparseTensor, dims: list[int]) -> StructuredSparseTensor: new_v_to_ps = [t.v_to_ps[d] for d in dims] @@ -165,7 +166,7 @@ def permute_default(t: StructuredSparseTensor, dims: list[int]) -> StructuredSpa return StructuredSparseTensor(new_physical, new_v_to_ps) -@StructuredSparseTensor.implements(aten.cat.default) +@impl(aten.cat.default) def cat_default(tensors: list[Tensor], dim: int) -> Tensor: if any(not isinstance(t, StructuredSparseTensor) for t in tensors): print_fallback(aten.cat.default, (tensors, dim), {}) @@ -217,7 +218,7 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: return StructuredSparseTensor(new_physical, new_v_to_ps) -@StructuredSparseTensor.implements(aten.expand.default) +@impl(aten.expand.default) def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSparseTensor: # note that sizes could also be just an int, or a torch.Size i think assert isinstance(t, StructuredSparseTensor) @@ -252,7 +253,7 @@ def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSpa return StructuredSparseTensor(new_physical, new_v_to_ps) -@StructuredSparseTensor.implements(aten.broadcast_tensors.default) +@impl(aten.broadcast_tensors.default) def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: if len(tensors) != 2: raise NotImplementedError() @@ -279,7 +280,7 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape) -@StructuredSparseTensor.implements(aten.slice.Tensor) +@impl(aten.slice.Tensor) def slice_Tensor( t: StructuredSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 ) -> StructuredSparseTensor: @@ -315,7 +316,7 @@ def slice_Tensor( return StructuredSparseTensor(new_physical, t.v_to_ps) -@StructuredSparseTensor.implements(aten.transpose.int) +@impl(aten.transpose.int) def transpose_int(t: StructuredSparseTensor, dim0: int, dim1: int) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py index 33cbb31fa..f438a4214 100644 --- a/src/torchjd/sparse/_structured_sparse_tensor.py +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -126,6 +126,9 @@ def decorator(func): return decorator +impl = StructuredSparseTensor.implements + + def print_fallback(func, args, kwargs) -> None: def tensor_to_str(t: Tensor) -> str: result = f"{t.__class__.__name__} - shape: {t.shape}" From 93e3a6137a047bde2f15378618b62f80521f30cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 2 Nov 2025 18:37:26 +0100 Subject: [PATCH 164/182] Improve error message in cat_default --- src/torchjd/sparse/_aten_function_overrides/shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index c697f94b9..2c8c67706 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -178,7 +178,7 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]): raise NotImplementedError( "Override for aten.cat.default does not support SSTs that do not all 
have the same " - f"strides. Found the following strides:\n{[t.strides for t in tensors_]} and the " + f"strides. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the " f"following dim: {dim}." ) From 3c3b4a43f8a5a3ac3e2612d71d396ce01bcdbd03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 3 Nov 2025 06:21:09 +0100 Subject: [PATCH 165/182] Rename variables in einsum --- .../sparse/_aten_function_overrides/einsum.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index 4732292be..01643b7c0 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -111,7 +111,7 @@ def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) - # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in # the resulting einsum). # Note that this is a problem if two virtual dimensions (from possibly different - # DiagonaSparseTensors) have the same size but not the same decomposition into physical + # StructuredSparseTensors) have the same size but not the same decomposition into physical # dimension sizes. For now lets leave the responsibility to care about that in the calling # functions, if we can factor code later on we will. @@ -134,16 +134,16 @@ def group_indices(indices: list[tuple[int, int]]) -> None: index_parents[curr_representative] = first_representative new_indices_pair = list[list[tuple[int, int]]]() - tensors = list[Tensor]() + physicals = list[Tensor]() indices_to_n_pdims = dict[int, int]() for t, indices in args: assert isinstance(t, StructuredSparseTensor) - tensors.append(t.physical) - for ps, index in zip(t.v_to_ps, indices): + physicals.append(t.physical) + for pdims, index in zip(t.v_to_ps, indices): if index in indices_to_n_pdims: - assert indices_to_n_pdims[index] == len(ps) + assert indices_to_n_pdims[index] == len(pdims) else: - indices_to_n_pdims[index] = len(ps) + indices_to_n_pdims[index] = len(pdims) p_to_vs = p_to_vs_from_v_to_ps(t.v_to_ps) for indices_ in p_to_vs: # elements in indices[indices_] map to the same dimension, they should be clustered @@ -181,7 +181,7 @@ def unique_int(pair: tuple[int, int]) -> int: new_output.append(k) v_to_ps.append(current_v_to_ps) - physical = torch.einsum(*[x for y in zip(tensors, new_indices) for x in y], new_output) + physical = torch.einsum(*[x for y in zip(physicals, new_indices) for x in y], new_output) # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped. # Maybe there is a way to fix that though. 
return to_most_efficient_tensor(physical, v_to_ps) From b303501bcef1a643c16f8c3d73a69bbe8357e08d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 3 Nov 2025 06:28:40 +0100 Subject: [PATCH 166/182] Improve error message in einsum --- src/torchjd/sparse/_aten_function_overrides/einsum.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index 01643b7c0..df7f559cc 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -141,7 +141,13 @@ def group_indices(indices: list[tuple[int, int]]) -> None: physicals.append(t.physical) for pdims, index in zip(t.v_to_ps, indices): if index in indices_to_n_pdims: - assert indices_to_n_pdims[index] == len(pdims) + if indices_to_n_pdims[index] != len(pdims): + raise NotImplementedError( + "einsum currently does not support having a different number of physical " + "dimensions corresponding to matching virtual dimensions of different " + f"tensors. Found {[(t.debug_info(), indices) for t, indices in args]}, " + f"output_indices={output}." + ) else: indices_to_n_pdims[index] = len(pdims) p_to_vs = p_to_vs_from_v_to_ps(t.v_to_ps) From 81dd29e11b7dc853386d82fbb77aa69f23d2c8d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 3 Nov 2025 06:53:39 +0100 Subject: [PATCH 167/182] Add failing parametrization of test_einsum --- tests/unit/sparse/test_structured_sparse_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py index f4097e538..7c93ba0a4 100644 --- a/tests/unit/sparse/test_structured_sparse_tensor.py +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -44,6 +44,7 @@ def test_to_dense2(): ([4, 5], [[0], [0], [1]], [4, 5], [[0], [1], [1]], [0, 1, 2], [0, 2, 3], [0, 1, 3]), ([2, 3, 5], [[0, 1], [2, 0]], [10, 3], [[0], [1]], [0, 1], [1, 2], [0, 2]), ([2, 3], [[0, 1]], [6], [[0]], [0], [0], []), + ([6, 2, 3], [[0], [1], [2]], [2, 3], [[0, 1], [0], [1]], [0, 1, 2], [0, 1, 2], [0, 1, 2]), ], ) def test_einsum( From 8593a7414511b47dcbc6acd7de801a0c2acfb7c8 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 3 Nov 2025 15:14:07 +0100 Subject: [PATCH 168/182] Move `einsum` to the top of the file `einsum.py` --- .../sparse/_aten_function_overrides/einsum.py | 132 +++++++++--------- 1 file changed, 66 insertions(+), 66 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index df7f559cc..89a9f32a4 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -11,72 +11,6 @@ ) -def prepare_for_elementwise_op( - t1: Tensor | int | float, t2: Tensor | int | float -) -> tuple[StructuredSparseTensor, StructuredSparseTensor]: - """ - Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being - a SST, Tensor, int or float. 
- """ - - assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor) - - if isinstance(t1, int) or isinstance(t1, float): - t1_ = tensor(t1, device=t2.device) - else: - t1_ = t1 - - if isinstance(t2, int) or isinstance(t2, float): - t2_ = tensor(t2, device=t1.device) - else: - t2_ = t2 - - t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) - t1_ = to_structured_sparse_tensor(t1_) - t2_ = to_structured_sparse_tensor(t2_) - - return t1_, t2_ - - -@impl(aten.mul.Tensor) -def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: - # Element-wise multiplication with broadcasting - t1_, t2_ = prepare_for_elementwise_op(t1, t2) - all_dims = list(range(t1_.ndim)) - return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) - - -@impl(aten.div.Tensor) -def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: - t1_, t2_ = prepare_for_elementwise_op(t1, t2) - t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.v_to_ps) - all_dims = list(range(t1_.ndim)) - return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) - - -@impl(aten.mul.Scalar) -def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor: - # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check - # that - - assert isinstance(t, StructuredSparseTensor) - new_physical = aten.mul.Scalar(t.physical, scalar) - return StructuredSparseTensor(new_physical, t.v_to_ps) - - -@impl(aten.add.Tensor) -def add_Tensor( - t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 -) -> StructuredSparseTensor: - t1_, t2_ = prepare_for_elementwise_op(t1, t2) - - if t1_.v_to_ps == t2_.v_to_ps: - new_physical = t1_.physical + t2_.physical * alpha - return StructuredSparseTensor(new_physical, t1_.v_to_ps) - else: - raise NotImplementedError() - - def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -> Tensor: # First part of the algorithm, determine how to cluster physical indices as well as the common @@ -193,6 +127,72 @@ def unique_int(pair: tuple[int, int]) -> int: return to_most_efficient_tensor(physical, v_to_ps) +def prepare_for_elementwise_op( + t1: Tensor | int | float, t2: Tensor | int | float +) -> tuple[StructuredSparseTensor, StructuredSparseTensor]: + """ + Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being + a SST, Tensor, int or float. 
+ """ + + assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor) + + if isinstance(t1, int) or isinstance(t1, float): + t1_ = tensor(t1, device=t2.device) + else: + t1_ = t1 + + if isinstance(t2, int) or isinstance(t2, float): + t2_ = tensor(t2, device=t1.device) + else: + t2_ = t2 + + t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) + t1_ = to_structured_sparse_tensor(t1_) + t2_ = to_structured_sparse_tensor(t2_) + + return t1_, t2_ + + +@impl(aten.mul.Tensor) +def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + # Element-wise multiplication with broadcasting + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@impl(aten.div.Tensor) +def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.v_to_ps) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@impl(aten.mul.Scalar) +def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor: + # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check + # that + + assert isinstance(t, StructuredSparseTensor) + new_physical = aten.mul.Scalar(t.physical, scalar) + return StructuredSparseTensor(new_physical, t.v_to_ps) + + +@impl(aten.add.Tensor) +def add_Tensor( + t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 +) -> StructuredSparseTensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + + if t1_.v_to_ps == t2_.v_to_ps: + new_physical = t1_.physical + t2_.physical * alpha + return StructuredSparseTensor(new_physical, t1_.v_to_ps) + else: + raise NotImplementedError() + + @impl(aten.bmm.default) def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) From 4ccef3a0db2ca958d99a3247678dbe160a0452c5 Mon Sep 17 00:00:00 2001 From: Matthieu Buot de l'Epine Date: Wed, 5 Nov 2025 11:59:04 +0100 Subject: [PATCH 169/182] clear_null_stride_columns --- .../sparse/_structured_sparse_tensor.py | 13 +++++++++ .../sparse/test_structured_sparse_tensor.py | 29 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py index f438a4214..f5ab93b8c 100644 --- a/src/torchjd/sparse/_structured_sparse_tensor.py +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -422,3 +422,16 @@ def make_sst(physical: Tensor, v_to_ps: list[list[int]]) -> StructuredSparseTens physical, v_to_ps = fix_dim_of_size_1(physical, v_to_ps) physical, v_to_ps = fix_ungrouped_dims(physical, v_to_ps) return StructuredSparseTensor(physical, v_to_ps) + + +def clear_null_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: + """Remove columns of strides that are all 0 and sum the corresponding elements in the physical tensor.""" + all_zero_columns = (strides == 0).all(dim=0) + + if not (all_zero_columns).any(): + return physical, strides + + all_zero_columns_indices = all_zero_columns.nonzero().flatten().tolist() + physical = physical.sum(dim=all_zero_columns_indices) + strides = strides[:, ~all_zero_columns] + return physical, strides diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py index 
7c93ba0a4..61d7efb67 100644 --- a/tests/unit/sparse/test_structured_sparse_tensor.py +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -12,6 +12,7 @@ from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim from torchjd.sparse._structured_sparse_tensor import ( StructuredSparseTensor, + clear_null_stride_columns, encode_by_order, fix_ungrouped_dims, get_groupings, @@ -277,3 +278,31 @@ def test_concatenate( assert isinstance(res, StructuredSparseTensor) assert torch.all(torch.eq(res.to_dense(), expected)) + + +@mark.parametrize( + ["physical", "strides", "expected_physical", "expected_strides"], + [ + ([[1, 2, 3], [4, 5, 6]], [[1, 0], [1, 0], [2, 0]], [6, 15], [[1], [1], [2]]), + ( + [[1, 2, 3], [4, 5, 6]], + [[1, 1], [1, 0], [2, 0]], + [[1, 2, 3], [4, 5, 6]], + [[1, 1], [1, 0], [2, 0]], + ), + ], +) +def test_clear_null_stride_columns( + physical: list, + strides: list, + expected_physical: list, + expected_strides: list, +): + physical, strides = torch.tensor(physical), torch.tensor(strides) + expected_physical, expected_strides = torch.tensor(expected_physical), torch.tensor( + expected_strides + ) + + physical, strides = clear_null_stride_columns(physical, strides) + assert_close(physical, expected_physical) + assert_close(strides, expected_strides) From 85dc26be72a50200b10faf2e33143be582e50577 Mon Sep 17 00:00:00 2001 From: Matthieu Buot de l'Epine Date: Wed, 5 Nov 2025 12:37:26 +0100 Subject: [PATCH 170/182] refacto fix zero stride columns --- .../sparse/_structured_sparse_tensor.py | 12 ++--- .../sparse/test_structured_sparse_tensor.py | 45 +++++++++++-------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py index f5ab93b8c..daaa0f902 100644 --- a/src/torchjd/sparse/_structured_sparse_tensor.py +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -424,14 +424,14 @@ def make_sst(physical: Tensor, v_to_ps: list[list[int]]) -> StructuredSparseTens return StructuredSparseTensor(physical, v_to_ps) -def clear_null_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: +def fix_zero_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: """Remove columns of strides that are all 0 and sum the corresponding elements in the physical tensor.""" - all_zero_columns = (strides == 0).all(dim=0) + are_columns_zero = (strides == 0).all(dim=0) - if not (all_zero_columns).any(): + if not (are_columns_zero).any(): return physical, strides - all_zero_columns_indices = all_zero_columns.nonzero().flatten().tolist() - physical = physical.sum(dim=all_zero_columns_indices) - strides = strides[:, ~all_zero_columns] + zero_column_indices = torch.arange(len(are_columns_zero))[are_columns_zero].tolist() + physical = physical.sum(dim=zero_column_indices) + strides = strides[:, ~are_columns_zero] return physical, strides diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py index 61d7efb67..ea0055298 100644 --- a/tests/unit/sparse/test_structured_sparse_tensor.py +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -1,5 +1,6 @@ import torch from pytest import mark +from torch import Tensor, tensor from torch.ops import aten # type: ignore from torch.testing import assert_close from utils.tensors import randn_, tensor_, zeros_ @@ -12,9 +13,9 @@ from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim from torchjd.sparse._structured_sparse_tensor import 
( StructuredSparseTensor, - clear_null_stride_columns, encode_by_order, fix_ungrouped_dims, + fix_zero_stride_columns, get_groupings, ) @@ -283,26 +284,32 @@ def test_concatenate( @mark.parametrize( ["physical", "strides", "expected_physical", "expected_strides"], [ - ([[1, 2, 3], [4, 5, 6]], [[1, 0], [1, 0], [2, 0]], [6, 15], [[1], [1], [2]]), ( - [[1, 2, 3], [4, 5, 6]], - [[1, 1], [1, 0], [2, 0]], - [[1, 2, 3], [4, 5, 6]], - [[1, 1], [1, 0], [2, 0]], + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 0], [1, 0], [2, 0]]), + tensor_([6, 15]), + tensor([[1], [1], [2]]), + ), + ( + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 1], [1, 0], [2, 0]]), + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 1], [1, 0], [2, 0]]), + ), + ( + tensor_([[3, 2, 1], [6, 5, 4]]), + tensor([[0, 0], [0, 0], [0, 0]]), + tensor_(21), + tensor([[], [], []], dtype=torch.int64), ), ], ) -def test_clear_null_stride_columns( - physical: list, - strides: list, - expected_physical: list, - expected_strides: list, +def test_fix_zero_stride_columns( + physical: Tensor, + strides: Tensor, + expected_physical: Tensor, + expected_strides: Tensor, ): - physical, strides = torch.tensor(physical), torch.tensor(strides) - expected_physical, expected_strides = torch.tensor(expected_physical), torch.tensor( - expected_strides - ) - - physical, strides = clear_null_stride_columns(physical, strides) - assert_close(physical, expected_physical) - assert_close(strides, expected_strides) + physical, strides = fix_zero_stride_columns(physical, strides) + assert torch.equal(physical, expected_physical) + assert torch.equal(strides, expected_strides) From 1c73416217f6e997e9d68b5320a42e31b9c2cbd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 3 Nov 2025 09:44:47 +0100 Subject: [PATCH 171/182] Add jupyter notebooks to .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 26ecc8b38..0b7d1aa67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Jupyter notebooks +*.ipynb + # uv uv.lock From 5547ff64999e3fbe784e1f33b2d7f508b1fe8c2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 5 Nov 2025 17:10:44 +0100 Subject: [PATCH 172/182] Use strides-based representation instead of v_to_ps-based * still need to update einsum and view, update tests, debug, and maybe delete some unused functions. 
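For reference, the intended encoding can be sketched without any of the library code: a
StructuredSparseTensor stores a dense `physical` tensor plus an integer `strides` matrix of shape
[virtual_ndim, physical_ndim], and the virtual tensor has value `physical[p_index]` at index
`strides @ p_index` (every other entry is zero), with virtual shape `strides @ (pshape - 1) + 1`.
A minimal illustrative sketch using plain torch only — the `densify` helper and its brute-force
loop are assumptions for the example, not part of the library:

    import itertools
    import torch

    def densify(physical: torch.Tensor, strides: torch.Tensor) -> torch.Tensor:
        # Virtual shape comes from the largest reachable virtual index: strides @ (pshape - 1) + 1.
        pshape = torch.tensor(physical.shape)
        vshape = strides @ (pshape - 1) + 1
        dense = torch.zeros(tuple(vshape.tolist()), dtype=physical.dtype)
        for p_index in itertools.product(*[range(s) for s in physical.shape]):
            v_index = strides @ torch.tensor(p_index)  # v_index = strides @ p_index
            dense[tuple(v_index.tolist())] = physical[p_index]
        return dense

    a = torch.randn(3)
    # With strides [[1], [1]], both virtual dimensions advance together with the single physical
    # dimension, which is exactly the diagonal-matrix case.
    assert torch.equal(densify(a, torch.tensor([[1], [1]])), torch.diag(a))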
--- .../_aten_function_overrides/backward.py | 6 +- .../sparse/_aten_function_overrides/einsum.py | 11 +- .../_aten_function_overrides/pointwise.py | 6 +- .../sparse/_aten_function_overrides/shape.py | 114 ++----- .../sparse/_structured_sparse_tensor.py | 301 +++++------------- 5 files changed, 118 insertions(+), 320 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py index 77a9fa57f..9168c7653 100644 --- a/src/torchjd/sparse/_aten_function_overrides/backward.py +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -10,7 +10,7 @@ def threshold_backward_default( ) -> StructuredSparseTensor: new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) - return StructuredSparseTensor(new_physical, grad_output.v_to_ps) + return StructuredSparseTensor(new_physical, grad_output.strides) @impl(aten.hardtanh_backward.default) @@ -24,7 +24,7 @@ def hardtanh_backward_default( raise NotImplementedError() new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) - return StructuredSparseTensor(new_physical, grad_output.v_to_ps) + return StructuredSparseTensor(new_physical, grad_output.strides) @impl(aten.hardswish_backward.default) @@ -33,4 +33,4 @@ def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor raise NotImplementedError() new_physical = aten.hardswish_backward.default(grad_output.physical, self) - return StructuredSparseTensor(new_physical, grad_output.v_to_ps) + return StructuredSparseTensor(new_physical, grad_output.strides) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index 89a9f32a4..45c6cd25d 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -5,7 +5,6 @@ from torchjd.sparse._structured_sparse_tensor import ( StructuredSparseTensor, impl, - p_to_vs_from_v_to_ps, to_most_efficient_tensor, to_structured_sparse_tensor, ) @@ -84,7 +83,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None: ) else: indices_to_n_pdims[index] = len(pdims) - p_to_vs = p_to_vs_from_v_to_ps(t.v_to_ps) + p_to_vs = ... 
# p_to_vs_from_v_to_ps(t.v_to_ps) for indices_ in p_to_vs: # elements in indices[indices_] map to the same dimension, they should be clustered # together @@ -165,7 +164,7 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: @impl(aten.div.Tensor) def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.v_to_ps) + t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.strides) all_dims = list(range(t1_.ndim)) return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) @@ -177,7 +176,7 @@ def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) new_physical = aten.mul.Scalar(t.physical, scalar) - return StructuredSparseTensor(new_physical, t.v_to_ps) + return StructuredSparseTensor(new_physical, t.strides) @impl(aten.add.Tensor) @@ -186,9 +185,9 @@ def add_Tensor( ) -> StructuredSparseTensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - if t1_.v_to_ps == t2_.v_to_ps: + if torch.equal(t1_.strides, t2_.strides): new_physical = t1_.physical + t2_.physical * alpha - return StructuredSparseTensor(new_physical, t1_.v_to_ps) + return StructuredSparseTensor(new_physical, t1_.strides) else: raise NotImplementedError() diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py index 85798c65f..9d389c10b 100644 --- a/src/torchjd/sparse/_aten_function_overrides/pointwise.py +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -71,7 +71,7 @@ def _override_pointwise(op): @impl(op) def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) - return StructuredSparseTensor(op(t.physical), t.v_to_ps) + return StructuredSparseTensor(op(t.physical), t.strides) return func_ @@ -100,7 +100,7 @@ def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredS return aten.pow.Tensor_Scalar(t.to_dense(), exponent) new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) - return StructuredSparseTensor(new_physical, t.v_to_ps) + return StructuredSparseTensor(new_physical, t.strides) # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. 
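# Side note (illustrative, plain torch only): forwarding a pointwise op to the physical data while
# keeping the strides unchanged is only equivalent to applying it to the dense tensor when the op
# maps 0 to 0, since the entries that are not stored are implicitly zero. This appears to be the
# assumption behind the overrides above; a quick check with a diagonal matrix:
import torch

a = torch.tensor([1.0, -2.0, 3.0])
dense = torch.diag(a)  # dense view of an SST with physical `a` and strides [[1], [1]]

# relu(0) == 0, so operating on the stored values only is fine:
assert torch.equal(torch.diag(torch.relu(a)), torch.relu(dense))

# exp(0) == 1, so the same shortcut would silently drop the off-diagonal ones:
assert not torch.equal(torch.diag(torch.exp(a)), torch.exp(dense))

# Elementwise multiplication, by contrast, is an einsum with all indices shared, which is
# presumably why mul/div are routed through the einsum path earlier in this patch:
b = torch.randn(3, 3)
assert torch.allclose(torch.einsum("ij,ij->ij", dense, b), dense * b)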
@@ -122,4 +122,4 @@ def div_Scalar(t: StructuredSparseTensor, divisor: float) -> StructuredSparseTen assert isinstance(t, StructuredSparseTensor) new_physical = aten.div.Scalar(t.physical, divisor) - return StructuredSparseTensor(new_physical, t.v_to_ps) + return StructuredSparseTensor(new_physical, t.strides) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 2c8c67706..0aa38c948 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -2,13 +2,11 @@ from typing import cast import torch -from torch import Tensor, tensor +from torch import Tensor, arange, tensor from torch.ops import aten # type: ignore from torchjd.sparse._structured_sparse_tensor import ( StructuredSparseTensor, - encode_v_to_ps, - fix_dim_encoding, impl, print_fallback, to_most_efficient_tensor, @@ -136,10 +134,10 @@ def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTe if dim < 0: dim = t.ndim + dim + 1 - new_v_to_ps = [p for p in t.v_to_ps] # Deepcopy the list to not modify the original v_to_ps - new_v_to_ps.insert(dim, []) - - return StructuredSparseTensor(t.physical, new_v_to_ps) + new_strides = torch.concatenate( + [t.strides[:dim], torch.zeros(1, t.strides.shape[1], dtype=torch.int64), t.strides[dim:]] + ) + return StructuredSparseTensor(t.physical, new_strides) @impl(aten.squeeze.dims) @@ -153,17 +151,15 @@ def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Ten else: excluded = set(dims) - new_v_to_ps = [pdims for i, pdims in enumerate(t.v_to_ps) if i not in excluded] - - return to_most_efficient_tensor(t.physical, new_v_to_ps) + is_row_kept = [i not in excluded for i in range(t.ndim)] + new_strides = t.strides[is_row_kept] + return to_most_efficient_tensor(t.physical, new_strides) @impl(aten.permute.default) def permute_default(t: StructuredSparseTensor, dims: list[int]) -> StructuredSparseTensor: - new_v_to_ps = [t.v_to_ps[d] for d in dims] - - new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) - return StructuredSparseTensor(new_physical, new_v_to_ps) + new_strides = t.strides[torch.tensor(dims)] + return StructuredSparseTensor(t.physical, new_strides) @impl(aten.cat.default) @@ -197,25 +193,20 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: # Add a physical dimension pdim on which we can concatenate the physicals such that this # translates into a concatenation of the virtuals on virtual dimension dim. - # Stride-based representation: - # new_stride_column = torch.zeros(ref_tensor.ndim, dtype=torch.int) - # new_stride_column[dim] = ref_virtual_dim_size - pdim = ref_tensor.physical.ndim - new_v_to_ps = [[d for d in pdims] for pdims in ref_tensor.v_to_ps] - new_v_to_ps[dim] = [pdim] + new_v_to_ps[dim] - new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps) - source = list(range(len(destination))) - physicals = [t.physical.unsqueeze(-1).movedim(source, destination) for t in tensors_] + physicals = [t.physical.unsqueeze(-1) for t in tensors_] + new_stride_column = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64) + new_stride_column[dim, 0] = ref_virtual_dim_size + new_strides = torch.concatenate([ref_tensor.strides, new_stride_column], dim=1) else: # Such a physical dimension already exists. Note that an alternative implementation would be # to simply always add the physical dimension, and squash it if it ends up being not needed. 
physicals = [t.physical for t in tensors_] pdim = indices[0][0] - new_v_to_ps = ref_tensor.v_to_ps + new_strides = ref_tensor.strides new_physical = aten.cat.default(physicals, dim=pdim) - return StructuredSparseTensor(new_physical, new_v_to_ps) + return StructuredSparseTensor(new_physical, new_strides) @impl(aten.expand.default) @@ -225,32 +216,32 @@ def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSpa assert isinstance(sizes, list) assert len(sizes) >= t.ndim + # Add as many dimensions as needed at the beginning of the tensor (as torch.expand works) for _ in range(len(sizes) - t.ndim): t = t.unsqueeze(0) - assert len(sizes) == t.ndim - + # Try to expand each dimension to its new size new_physical = t.physical - new_v_to_ps = t.v_to_ps - n_added_physical_dims = 0 - for dim, (ps, orig_size, new_size) in enumerate(zip(t.v_to_ps, t.shape, sizes, strict=True)): - if len(ps) > 0 and orig_size != new_size and new_size != -1: + new_strides = t.strides + for d, (vstride, orig_size, new_size) in enumerate(zip(t.strides, t.shape, sizes, strict=True)): + if vstride.sum() > 0 and orig_size != new_size and new_size != -1: raise ValueError( - f"Cannot expand dim {dim} of size != 1. Found size {orig_size} and target size " + f"Cannot expand dim {d} of size != 1. Found size {orig_size} and target size " f"{new_size}." ) - if len(ps) == 0 and new_size != 1 and new_size != -1: + if vstride.sum() == 0 and new_size != 1 and new_size != -1: # Add a dimension of size new_size at the end of the physical tensor. new_physical_shape = list(new_physical.shape) + [new_size] new_physical = new_physical.unsqueeze(-1).expand(new_physical_shape) - new_v_to_ps[dim] = [t.physical.ndim + n_added_physical_dims] - n_added_physical_dims += 1 - new_v_to_ps, destination = encode_v_to_ps(new_v_to_ps) - new_physical = new_physical.movedim(list(range(len(destination))), destination) + # Make this new physical dimension have a stride of 1 at virtual dimension d and 0 at + # every other virtual dimension + new_stride_column = torch.zeros(t.ndim, 1, dtype=torch.int64) + new_stride_column[d, 0] = 1 + new_strides = torch.cat([new_strides, new_stride_column], dim=1) - return StructuredSparseTensor(new_physical, new_v_to_ps) + return StructuredSparseTensor(new_physical, new_strides) @impl(aten.broadcast_tensors.default) @@ -280,49 +271,14 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape) -@impl(aten.slice.Tensor) -def slice_Tensor( - t: StructuredSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1 -) -> StructuredSparseTensor: - assert isinstance(t, StructuredSparseTensor) - - physical_dims = t.v_to_ps[dim] - - if len(physical_dims) > 1: - raise ValueError( - "Cannot yet slice virtual dim corresponding to several physical dims.\n" - f"{t.debug_info()}\n" - f"dim={dim}, start={start}, end={end}, step={step}." - ) - elif len(physical_dims) == 0: - # Trying to slice a virtual dim of size 1. - # Either - # - the element of this dim is included in the slice: keep it as it is - # - it's not included in the slice (e.g. end<=start): we would end up with a size of 0 on - # that dimension, so we'd need to add a dimension of size 0 to the physical. This is not - # implemented yet. 
- start_ = start if start is not None else 0 - end_ = end if end is not None else 1 - if end_ <= start_: # TODO: the condition might be a bit more complex if step != 1 - raise NotImplementedError( - "Slicing of dimension of size 1 leading to dimension of size 0 not implemented yet." - ) - else: - new_physical = t.physical - else: - physical_dim = physical_dims[0] - new_physical = aten.slice.Tensor(t.physical, physical_dim, start, end, step) - - return StructuredSparseTensor(new_physical, t.v_to_ps) - - @impl(aten.transpose.int) def transpose_int(t: StructuredSparseTensor, dim0: int, dim1: int) -> StructuredSparseTensor: assert isinstance(t, StructuredSparseTensor) + return StructuredSparseTensor(t.physical, _swap_rows(t.strides, dim0, dim1)) - new_v_to_ps = [dims for dims in t.v_to_ps] - new_v_to_ps[dim0] = t.v_to_ps[dim1] - new_v_to_ps[dim1] = t.v_to_ps[dim0] - new_physical, new_v_to_ps = fix_dim_encoding(t.physical, new_v_to_ps) - return StructuredSparseTensor(new_physical, new_v_to_ps) +def _swap_rows(matrix: Tensor, c0: int, c1: int) -> Tensor: + index = arange(matrix.shape[0]) + index[c0] = c1 + index[c1] = c0 + return matrix[index] diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py index daaa0f902..a8641a1f5 100644 --- a/src/torchjd/sparse/_structured_sparse_tensor.py +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -6,14 +6,14 @@ import torch from torch import Tensor, arange, meshgrid, stack, tensor, tensordot, zeros -from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import tree_map class StructuredSparseTensor(Tensor): _HANDLED_FUNCTIONS = dict() @staticmethod - def __new__(cls, physical: Tensor, v_to_ps: list[list[int]]): + def __new__(cls, physical: Tensor, strides: Tensor): # At the moment, this class is not compositional, so we assert # that the tensor we're wrapping is exactly a Tensor assert type(physical) is Tensor @@ -28,20 +28,24 @@ def __new__(cls, physical: Tensor, v_to_ps: list[list[int]]): # (which is bad!) assert not physical.requires_grad or not torch.is_grad_enabled() - shape = [prod(physical.shape[i] for i in dims) for dims in v_to_ps] + pshape = torch.tensor(physical.shape) + vshape = strides @ (pshape - 1) + 1 return Tensor._make_wrapper_subclass( - cls, shape, dtype=physical.dtype, device=physical.device + cls, vshape, dtype=physical.dtype, device=physical.device ) - def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): + def __init__(self, physical: Tensor, strides: Tensor): """ - This constructor is made for specifying physical and v_to_ps exactly. It should not modify + This constructor is made for specifying physical and strides exactly. It should not modify it. - For this reason, another constructor will be made to either modify the physical / v_to_ps to - simplify the result, or to create a dense tensor directly if it's already dense. It could - also be responsible for sorting the first apparition of each physical dim in the flattened - v_to_ps. + For this reason, another constructor will be made to either modify the physical / strides to + simplify the result, or to create a dense tensor directly if it's already dense. + + :param physical: The dense tensor holding the actual data. + :param strides: Integer (int64) tensor of shape [virtual_ndim, physical_ndim], representing + the linear transformation between an index in the physical tensor and the corresponding + index in the virtual tensor, i.e. v_index = strides @ p_index. 
""" if any(s == 1 for s in physical.shape): @@ -49,26 +53,25 @@ def __init__(self, physical: Tensor, v_to_ps: list[list[int]]): "physical must not contain any dimension of size 1. Found physical.shape=" f"{physical.shape}." ) - if not all(all(0 <= dim < physical.ndim for dim in dims) for dims in v_to_ps): + if strides.dtype is not torch.int64: raise ValueError( - f"Elements in v_to_ps must map to dimensions in physical. Found {v_to_ps}." + f"strides should be of int64 dtype. Found strides.dtype={strides.dtype}." ) - if len(set().union(*[set(dims) for dims in v_to_ps])) != physical.ndim: - raise ValueError("Every dimension in physical must appear at least once in v_to_ps.") - - if v_to_ps != encode_v_to_ps(v_to_ps)[0]: + if not (strides >= 0).all(): + raise ValueError(f"All strides must be non-negative. Found strides={strides}.") + if strides.shape[1] != physical.ndim: raise ValueError( - f"v_to_ps elements are not encoded by first appearance. Found {v_to_ps}." + f"strides should have 1 column per physical dimension. Found strides={strides} and physical.shape={physical.shape}." ) + if (strides.sum(dim=0) == 0).any(): + raise ValueError( + f"strides should not have any column full of zeros. Found strides={strides}." + ) + if any(len(group) != 1 for group in get_groupings(list(physical.shape), strides)): + raise ValueError(f"Dimensions must be maximally grouped. Found strides={strides}.") self.physical = physical - self.v_to_ps = v_to_ps - - # strides is of shape [v_ndim, p_ndim], such that v_index = strides @ p_index - self.strides = get_strides(list(self.physical.shape), v_to_ps) - - if any(len(group) != 1 for group in get_groupings(list(self.physical.shape), self.strides)): - raise ValueError(f"Dimensions must be maximally grouped. Found {v_to_ps}.") + self.strides = strides def to_dense( self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None @@ -101,16 +104,13 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(*unwrapped_args, **unwrapped_kwargs) def __repr__(self, *, tensor_contents=None) -> str: - return f"StructuredSparseTensor(physical={self.physical}, v_to_ps={self.v_to_ps})" + return f"StructuredSparseTensor(physical={self.physical}, strides={self.strides})" def debug_info(self) -> str: info = ( - f"shape: {self.shape}\n" - f"stride(): {self.stride()}\n" - f"v_to_ps: {self.v_to_ps}\n" + f"vshape: {self.shape}\n" + f"pshape: {self.physical.shape}\n" f"strides: {self.strides}\n" - f"physical.shape: {self.physical.shape}\n" - f"physical.stride(): {self.physical.stride()}" ) return info @@ -131,9 +131,9 @@ def decorator(func): def print_fallback(func, args, kwargs) -> None: def tensor_to_str(t: Tensor) -> str: - result = f"{t.__class__.__name__} - shape: {t.shape}" + result = f"{t.__class__.__name__} - vshape: {t.shape}" if isinstance(t, StructuredSparseTensor): - result += f" - pshape: {t.physical.shape} - v_to_ps: {t.v_to_ps}" + result += f" - pshape: {t.physical.shape} - strides: {t.strides}" return result @@ -155,12 +155,6 @@ def tensor_to_str(t: Tensor) -> str: print() -def strides_from_p_dims_and_p_shape(p_dims: list[int], physical_shape: list[int]) -> list[int]: - return list(accumulate([1] + [physical_shape[dim] for dim in p_dims[:0:-1]], operator.mul))[ - ::-1 - ] - - def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]: """ From a list of physical dimensions corresponding to a virtual dimension, and from the physical @@ -182,85 +176,15 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> 
list[int]: 1 on physical dimension 1 makes you move by 1 on the virtual dimension. """ - strides_v1 = strides_from_p_dims_and_p_shape(p_dims, physical_shape) + strides_v1 = list(accumulate([1] + [physical_shape[d] for d in p_dims[:0:-1]], operator.mul))[ + ::-1 + ] result = [0 for _ in range(len(physical_shape))] for i, d in enumerate(p_dims): result[d] += strides_v1[i] return result -def get_strides(pshape: list[int], v_to_ps: list[list[int]]) -> Tensor: - strides = torch.tensor([strides_v2(pdims, pshape) for pdims in v_to_ps], dtype=torch.int64) - - # It's sometimes necessary to reshape: when v_to_ps contains 0 element for instance. - return strides.reshape(len(v_to_ps), len(pshape)) - - -def argmax(iterable): - return max(enumerate(iterable), key=lambda x: x[1])[0] - - -def strides_to_pdims(strides: Tensor, physical_shape: list[int]) -> list[int]: - """ - Given a list of strides, find and return the used physical dimensions. - - This algorithm runs in O(n * m) with n the number of physical dimensions (i.e. - len(physical_shape) and len(strides)), and with m the number of pdims in the result. - - I'm pretty sure it could be implemented in O((n+m)log(n)) by using a sorted linked list for the - remaining_strides, and keeping it sorted each time we update it. Argmax would then always be 0, - removing the need to go through the whole list at every iteration. - """ - - # e.g. strides = [22111, 201000], physical_shape = [10, 2] - - pdims = [] - remaining_strides = strides.clone() - remaining_numel = ( - sum(remaining_strides[i] * (physical_shape[i] - 1) for i in range(len(physical_shape))) + 1 - ) - # e.g. 9 * 22111 + 1 * 201000 + 1 = 400000 - - while sum(remaining_strides) > 0: - current_pdim = argmax(remaining_strides) - # e.g. 1 - - pdims.append(current_pdim) - - remaining_numel = remaining_numel // physical_shape[current_pdim] - # e.g. 400000 / 2 = 200000 - - remaining_strides[current_pdim] -= remaining_numel - # e.g. [22111, 1000] - - return pdims - - -def merge_strides(strides: list[list[int]]) -> list[int]: - return sorted({s for stride in strides for s in stride}, reverse=True) - - -def stride_to_shape(numel: int, stride: list[int]) -> list[int]: - augmented_stride = [numel] + stride - return [a // b for a, b in zip(augmented_stride[:-1], augmented_stride[1:])] - - -def p_to_vs_from_v_to_ps(v_to_ps: list[list[int]]) -> list[list[tuple[int, int]]]: - """ - A physical dimension is mapped to a list of couples of the form - (virtual_dim, sub_index_in_virtual_dim) - """ - - res = dict[int, list[tuple[int, int]]]() - for v_dim, p_dims in enumerate(v_to_ps): - for i, p_dim in enumerate(p_dims): - if p_dim not in res: - res[p_dim] = [(v_dim, i)] - else: - res[p_dim].append((v_dim, i)) - return [res[i] for i in range(len(res))] - - def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]: strides_time_pshape = strides * tensor(pshape) groups = {i: {i} for i, column in enumerate(strides.T)} @@ -278,72 +202,38 @@ def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]: return new_columns -def encode_by_order(input: list[int]) -> tuple[list[int], list[int]]: - """ - Encodes values based on the order of their first appearance, starting at 0 and incrementing. - - Returns the encoded list and the destination mapping each original int to its new encoding. - destination[i] = j means that all elements of value i in input are mapped to j in the encoded - list. - - The input list should only contain consecutive integers starting at 0. 
- - Examples: - [1, 0, 3, 2] => [0, 1, 2, 3], [1, 0, 3, 2] - [0, 2, 0, 1] => [0, 1, 0, 2], [0, 2, 1] - [1, 0, 0, 1] => [0, 1, 1, 0], [1, 0] - """ - - mapping = dict[int, int]() - curr = 0 - output = [] - for v in input: - if v not in mapping: - mapping[v] = curr - curr += 1 - output.append(mapping[v]) - destination = [mapping[i] for i in range(len(mapping))] - - return output, destination - - -def encode_v_to_ps(v_to_ps: list[list[int]]) -> tuple[list[list[int]], list[int]]: - flat_v_to_ps, spec = tree_flatten(v_to_ps) - sorted_flat_v_to_ps, destination = encode_by_order(flat_v_to_ps) - return tree_unflatten(sorted_flat_v_to_ps, spec), destination - - def to_structured_sparse_tensor(t: Tensor) -> StructuredSparseTensor: if isinstance(t, StructuredSparseTensor): return t else: - return make_sst(t, [[i] for i in range(t.ndim)]) - - -def to_most_efficient_tensor(physical: Tensor, v_to_ps: list[list[int]]) -> Tensor: - physical, v_to_ps = fix_dim_encoding(physical, v_to_ps) - physical, v_to_ps = fix_dim_of_size_1(physical, v_to_ps) - physical, v_to_ps = fix_ungrouped_dims(physical, v_to_ps) - - if sum([len(pdims) for pdims in v_to_ps]) == physical.ndim: - next_physical_index = physical.ndim - new_v_to_ps = [] - # Add as many dimensions of size 1 as there are pdims equal to [] in v_to_ps. - # Create the corresponding new_v_to_ps. - # E.g. if v_to_ps is [[0], [], [1]], new_v_to_ps is [[0], [2], [1]]. - for vdim, pdims in enumerate(v_to_ps): - if len(pdims) == 0: - physical = physical.unsqueeze(-1) - new_v_to_ps.append([next_physical_index]) - next_physical_index += 1 - else: - new_v_to_ps.append(pdims) - - return torch.movedim( - physical, list(range(physical.ndim)), [pdims[0] for pdims in new_v_to_ps] - ) + return make_sst(physical=t, strides=torch.eye(t.ndim, dtype=torch.int64)) + + +def to_most_efficient_tensor(physical: Tensor, strides: Tensor) -> Tensor: + physical, strides = fix_dim_of_size_1(physical, strides) + physical, strides = fix_ungrouped_dims(physical, strides) + + if (strides.sum(dim=0) == 1).all(): + # All physical dimensions make you move by 1 in exactly 1 virtual dimension. + # Also, because all physical dimensions have been maximally grouped, we cannot have two + # physical dimensions that make you move in the same virtual dimension. + # So strides is an identity matrix with potentially some extra rows of zeros, and + # potentially shuffled columns. + + # The first step is to unsqueeze the physical tensor for each extra row of zeros in the + # strides. + zero_row_mask = strides.sum(dim=1) == 0 + number_of_zero_rows = zero_row_mask.sum() + for _ in number_of_zero_rows: + physical = physical.unsqueeze(-1) + + # The second step is to re-order the physical dimensions so that the corresponding + # strides matrix would be an identity. 
+ source = arange(strides.shape[0]) + destination = strides[zero_row_mask] @ source + return physical.movedim(list(source), list(destination)) else: - return StructuredSparseTensor(physical, v_to_ps) + return StructuredSparseTensor(physical, strides) def unwrap_to_dense(t: Tensor): @@ -353,57 +243,12 @@ def unwrap_to_dense(t: Tensor): return t -def to_target_physical_strides( - physical: Tensor, v_to_ps: list[list[int]], strides: list[list[int]] -) -> tuple[Tensor, list[list[int]]]: - current_strides = [ - strides_from_p_dims_and_p_shape(p_dims, list(physical.shape)) for p_dims in v_to_ps - ] - target_stride = merge_strides(strides) - - numel = physical.numel() - target_shape = stride_to_shape(numel, target_stride) - new_physical = physical.reshape(target_shape) - - stride_to_p_dim = {s: i for i, s in enumerate(target_stride)} - stride_to_p_dim[0] = len(target_shape) - - new_v_to_ps = list[list[int]]() - for stride in current_strides: - extended_stride = stride + [0] - new_p_dims = list[int]() - for s_curr, s_next in zip(extended_stride[:-1], extended_stride[1:]): - new_p_dims += range(stride_to_p_dim[s_curr], stride_to_p_dim[s_next]) - new_v_to_ps.append(new_p_dims) - - return new_physical, new_v_to_ps - - -def fix_dim_encoding(physical: Tensor, v_to_ps: list[list[int]]) -> tuple[Tensor, list[list[int]]]: - v_to_ps, destination = encode_v_to_ps(v_to_ps) - source = list(range(physical.ndim)) - physical = physical.movedim(source, destination) - - return physical, v_to_ps - - -def fix_dim_of_size_1(physical: Tensor, v_to_ps: list[list[int]]) -> tuple[Tensor, list[list[int]]]: - is_of_size_1 = [s == 1 for s in physical.shape] - - def new_encoding(d: int) -> int: - n_removed_dims_before_d = sum(is_of_size_1[:d]) - return d - n_removed_dims_before_d - - physical = physical.squeeze() - v_to_ps = [[new_encoding(d) for d in dims if not is_of_size_1[d]] for dims in v_to_ps] - - return physical, v_to_ps +def fix_dim_of_size_1(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: + is_of_size_1 = torch.tensor([s == 1 for s in physical.shape]) + return physical.squeeze(), strides[:, ~is_of_size_1] -def fix_ungrouped_dims( - physical: Tensor, v_to_ps: list[list[int]] -) -> tuple[Tensor, list[list[int]]]: - strides = get_strides(list(physical.shape), v_to_ps) +def fix_ungrouped_dims(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: groups = get_groupings(list(physical.shape), strides) nphysical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups]) stride_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64) @@ -411,17 +256,15 @@ def fix_ungrouped_dims( stride_mapping[group[-1], j] = 1 new_strides = strides @ stride_mapping - new_v_to_ps = [strides_to_pdims(stride, list(nphysical.shape)) for stride in new_strides] - return nphysical, new_v_to_ps + return nphysical, new_strides -def make_sst(physical: Tensor, v_to_ps: list[list[int]]) -> StructuredSparseTensor: - """Fix physical and v_to_ps and create a StructuredSparseTensor with them.""" +def make_sst(physical: Tensor, strides: Tensor) -> StructuredSparseTensor: + """Fix physical and strides and create a StructuredSparseTensor with them.""" - physical, v_to_ps = fix_dim_encoding(physical, v_to_ps) - physical, v_to_ps = fix_dim_of_size_1(physical, v_to_ps) - physical, v_to_ps = fix_ungrouped_dims(physical, v_to_ps) - return StructuredSparseTensor(physical, v_to_ps) + physical, strides = fix_dim_of_size_1(physical, strides) + physical, strides = fix_ungrouped_dims(physical, strides) + 
return StructuredSparseTensor(physical, strides) def fix_zero_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: From 7f2c9747d4e04e0131fbec98d8522a955cd72eb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 6 Nov 2025 15:53:41 +0100 Subject: [PATCH 173/182] Fix StructuredSparseTensor.__new__ --- src/torchjd/sparse/_structured_sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py index a8641a1f5..528a8eb4b 100644 --- a/src/torchjd/sparse/_structured_sparse_tensor.py +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -31,7 +31,7 @@ def __new__(cls, physical: Tensor, strides: Tensor): pshape = torch.tensor(physical.shape) vshape = strides @ (pshape - 1) + 1 return Tensor._make_wrapper_subclass( - cls, vshape, dtype=physical.dtype, device=physical.device + cls, tuple(vshape.tolist()), dtype=physical.dtype, device=physical.device ) def __init__(self, physical: Tensor, strides: Tensor): From 51a579aa43cea7c4a8219066aa873e3ff3083930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 6 Nov 2025 15:53:53 +0100 Subject: [PATCH 174/182] Fix tests --- .../sparse/test_structured_sparse_tensor.py | 212 ++++++++++++------ 1 file changed, 145 insertions(+), 67 deletions(-) diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py index ea0055298..9bff7f98c 100644 --- a/tests/unit/sparse/test_structured_sparse_tensor.py +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -13,7 +13,6 @@ from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim from torchjd.sparse._structured_sparse_tensor import ( StructuredSparseTensor, - encode_by_order, fix_ungrouped_dims, fix_zero_stride_columns, get_groupings, @@ -24,7 +23,7 @@ def test_to_dense(): n = 2 m = 3 a = randn_([n, m]) - b = StructuredSparseTensor(a, [[0], [1], [1], [0]]) + b = StructuredSparseTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]])) c = b.to_dense() for i in range(n): @@ -34,32 +33,56 @@ def test_to_dense(): def test_to_dense2(): a = tensor_([1.0, 2.0, 3.0]) - b = StructuredSparseTensor(a, [[0, 0]]) + b = StructuredSparseTensor(a, tensor([[4]])) c = b.to_dense() expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]) assert torch.all(torch.eq(c, expected)) @mark.parametrize( - ["a_pshape", "a_v_to_ps", "b_pshape", "b_v_to_ps", "a_indices", "b_indices", "output_indices"], + ["a_pshape", "a_strides", "b_pshape", "b_strides", "a_indices", "b_indices", "output_indices"], [ - ([4, 5], [[0], [0], [1]], [4, 5], [[0], [1], [1]], [0, 1, 2], [0, 2, 3], [0, 1, 3]), - ([2, 3, 5], [[0, 1], [2, 0]], [10, 3], [[0], [1]], [0, 1], [1, 2], [0, 2]), - ([2, 3], [[0, 1]], [6], [[0]], [0], [0], []), - ([6, 2, 3], [[0], [1], [2]], [2, 3], [[0, 1], [0], [1]], [0, 1, 2], [0, 1, 2], [0, 1, 2]), + ( + [4, 5], + tensor([[1, 0], [1, 0], [0, 1]]), + [4, 5], + tensor([[1, 0], [0, 1], [0, 1]]), + [0, 1, 2], + [0, 2, 3], + [0, 1, 3], + ), + ( + [2, 3, 5], + tensor([[3, 1, 0], [1, 0, 2]]), + [10, 3], + tensor([[1, 0], [0, 1]]), + [0, 1], + [1, 2], + [0, 2], + ), + ([2, 3], tensor([[3, 1]]), [6], tensor([[1]]), [0], [0], []), + ( + [6, 2, 3], + tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + [2, 3], + tensor([[3, 1], [1, 0], [0, 1]]), + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + ), ], ) def test_einsum( a_pshape: list[int], - a_v_to_ps: list[list[int]], + a_strides: Tensor, b_pshape: list[int], - 
b_v_to_ps: list[list[int]], + b_strides: Tensor, a_indices: list[int], b_indices: list[int], output_indices: list[int], ): - a = StructuredSparseTensor(randn_(a_pshape), a_v_to_ps) - b = StructuredSparseTensor(randn_(b_pshape), b_v_to_ps) + a = StructuredSparseTensor(randn_(a_pshape), a_strides) + b = StructuredSparseTensor(randn_(b_pshape), b_strides) res = einsum((a, a_indices), (b, b_indices), output=output_indices) @@ -80,7 +103,7 @@ def test_einsum( ) def test_structured_sparse_tensor_scalar(shape: list[int]): a = randn_(shape) - b = StructuredSparseTensor(a, [[dim] for dim in range(len(shape))]) + b = StructuredSparseTensor(a, torch.eye(len(shape), dtype=torch.int64)) assert_close(a, b.to_dense()) @@ -88,7 +111,7 @@ def test_structured_sparse_tensor_scalar(shape: list[int]): @mark.parametrize("dim", [2, 3, 4, 5, 10]) def test_diag_equivalence(dim: int): a = randn_([dim]) - b = StructuredSparseTensor(a, [[0], [0]]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) diag_a = torch.diag(a) @@ -98,7 +121,7 @@ def test_diag_equivalence(dim: int): def test_three_virtual_single_physical(): dim = 10 a = randn_([dim]) - b = StructuredSparseTensor(a, [[0], [0], [0]]) + b = StructuredSparseTensor(a, tensor([[1], [1], [1]])) expected = zeros_([dim, dim, dim]) for i in range(dim): @@ -111,7 +134,7 @@ def test_three_virtual_single_physical(): def test_pointwise(func): dim = 10 a = randn_([dim]) - b = StructuredSparseTensor(a, [[0], [0]]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) c = b.to_dense() res = func(b) assert isinstance(res, StructuredSparseTensor) @@ -123,7 +146,7 @@ def test_pointwise(func): def test_inplace_pointwise(func): dim = 10 a = randn_([dim]) - b = StructuredSparseTensor(a, [[0], [0]]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) c = b.to_dense() func(b) assert isinstance(b, StructuredSparseTensor) @@ -135,7 +158,7 @@ def test_inplace_pointwise(func): def test_unary(func): dim = 10 a = randn_([dim]) - b = StructuredSparseTensor(a, [[0], [0]]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) c = b.to_dense() res = func(b) @@ -143,61 +166,106 @@ def test_unary(func): @mark.parametrize( - ["physical_shape", "v_to_ps", "target_shape", "expected_physical_shape", "expected_v_to_ps"], + ["physical_shape", "strides", "target_shape", "expected_physical_shape", "expected_strides"], [ - ([2, 3], [[0], [0], [1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # no change of shape - ([2, 3], [[0], [0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # no change of shape - ([2, 3], [[0], [0], [1]], [2, 6], [2, 3], [[0], [0, 1]]), # squashing 2 dims - ([2, 3], [[0], [0, 1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # unsquashing into 2 dims - ([2, 3], [[0, 0, 1]], [2, 6], [2, 3], [[0], [0, 1]]), # unsquashing into 2 dims - ([2, 3], [[0], [0], [1]], [12], [2, 3], [[0, 0, 1]]), # squashing 3 dims - ([2, 3], [[0, 0, 1]], [2, 2, 3], [2, 3], [[0], [0], [1]]), # unsquashing into 3 dims - ([4], [[0], [0]], [2, 2, 4], [2, 2], [[0], [1], [0, 1]]), # unsquashing physical dim - ([4], [[0], [0]], [4, 2, 2], [2, 2], [[0, 1], [0], [1]]), # unsquashing physical dim - ([2, 3, 4], [[0], [0], [1], [2]], [4, 12], [2, 12], [[0, 0], [1]]), # world boss - ([2, 12], [[0, 0], [1]], [2, 2, 3, 4], [2, 3, 4], [[0], [0], [1], [2]]), # world boss + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # no change of shape + ( + [2, 3], + tensor([[1, 0], [3, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # no change of shape + ( + [2, 3], + tensor([[1, 0], [1, 
0], [0, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # squashing 2 dims + ( + [2, 3], + tensor([[1, 0], [3, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # unsquashing into 2 dims + ( + [2, 3], + tensor([[9, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # unsquashing into 2 dims + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [12], + [2, 3], + tensor([[9, 1]]), + ), # squashing 3 dims + ( + [2, 3], + tensor([[9, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # unsquashing into 3 dims + ( + [4], + tensor([[1], [1]]), + [2, 2, 4], + [2, 2], + tensor([[1, 0], [0, 1], [2, 1]]), + ), # unsquashing physical dim + ( + [4], + tensor([[1], [1]]), + [4, 2, 2], + [2, 2], + tensor([[2, 1], [1, 0], [0, 1]]), + ), # unsquashing physical dim + ( + [2, 3, 4], + tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]), + [4, 12], + [2, 12], + tensor([[3, 0], [0, 1]]), + ), # world boss + ( + [2, 12], + tensor([[3, 0], [0, 1]]), + [2, 2, 3, 4], + [2, 3, 4], + tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]), + ), # world boss ], ) def test_view( physical_shape: list[int], - v_to_ps: list[list[int]], + strides: Tensor, target_shape: list[int], expected_physical_shape: list[int], - expected_v_to_ps: list[list[int]], + expected_strides: Tensor, ): a = randn_(tuple(physical_shape)) - t = StructuredSparseTensor(a, v_to_ps) + t = StructuredSparseTensor(a, strides) result = aten.view.default(t, target_shape) expected = t.to_dense().reshape(target_shape) assert isinstance(result, StructuredSparseTensor) assert list(result.physical.shape) == expected_physical_shape - assert result.v_to_ps == expected_v_to_ps + assert torch.equal(result.strides, expected_strides) assert torch.all(torch.eq(result.to_dense(), expected)) -@mark.parametrize( - ["input", "expected_output", "expected_destination"], - [ - ([0, 1, 0, 2, 1, 3], [0, 1, 0, 2, 1, 3], [0, 1, 2, 3]), # trivial - ([1, 0, 3, 2, 1], [0, 1, 2, 3, 0], [1, 0, 3, 2]), - ([1, 0, 3, 2], [0, 1, 2, 3], [1, 0, 3, 2]), - ([0, 2, 0, 1], [0, 1, 0, 2], [0, 2, 1]), - ([1, 0, 0, 1], [0, 1, 1, 0], [1, 0]), - ], -) -def test_encode_by_order( - input: list[int], - expected_output: list[int], - expected_destination: list[int], -): - output, destination = encode_by_order(input) - - assert output == expected_output - assert destination == expected_destination - - @mark.parametrize( ["pshape", "strides", "expected"], [ @@ -214,24 +282,34 @@ def test_get_groupings(pshape: list[int], strides: torch.Tensor, expected: list[ @mark.parametrize( - ["physical_shape", "v_to_ps", "expected_physical_shape", "expected_v_to_ps"], + ["physical_shape", "strides", "expected_physical_shape", "expected_strides"], [ - ([3, 4, 5], [[0, 1, 2], [2, 0, 1], [2]], [12, 5], [[0, 1], [1, 0], [1]]), - ([32, 20, 8], [[0], [1, 0], [2]], [32, 20, 8], [[0], [1, 0], [2]]), - ([3, 3, 4], [[0, 1], [1, 2]], [3, 3, 4], [[0, 1], [1, 2]]), + ( + [3, 4, 5], + tensor([[20, 5, 1], [4, 1, 12], [0, 0, 1]]), + [12, 5], + tensor([[5, 1], [1, 12], [0, 1]]), + ), + ( + [32, 20, 8], + tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]), + [32, 20, 8], + tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]), + ), + ([3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]]), [3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]])), ], ) def test_fix_ungrouped_dims( physical_shape: list[int], - v_to_ps: list[list[int]], + strides: Tensor, expected_physical_shape: list[int], - expected_v_to_ps: list[list[int]], + expected_strides: Tensor, ): physical = randn_(physical_shape) - fixed_physical, fixed_v_to_ps = 
fix_ungrouped_dims(physical, v_to_ps) + fixed_physical, fixed_strides = fix_ungrouped_dims(physical, strides) assert list(fixed_physical.shape) == expected_physical_shape - assert fixed_v_to_ps == expected_v_to_ps + assert torch.equal(fixed_strides, expected_strides) @mark.parametrize( @@ -265,15 +343,15 @@ def test_unsquash_pdim( @mark.parametrize( ["sst_args", "dim"], [ - ([([3], [[0], [0]]), ([3], [[0], [0]])], 1), - ([([3, 2], [[0], [1, 0]]), ([3, 2], [[0], [1, 0]])], 1), + ([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1), + ([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1), ], ) def test_concatenate( - sst_args: list[tuple[list[int], list[list[int]]]], + sst_args: list[tuple[list[int], Tensor]], dim: int, ): - tensors = [StructuredSparseTensor(randn_(pshape), v_to_ps) for pshape, v_to_ps in sst_args] + tensors = [StructuredSparseTensor(randn_(pshape), strides) for pshape, strides in sst_args] res = aten.cat.default(tensors, dim) expected = aten.cat.default([t.to_dense() for t in tensors], dim) From 6e94a89a2d295292621165c47d427b501d74caab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 7 Nov 2025 17:56:38 +0100 Subject: [PATCH 175/182] Add initial implementation of view_default --- .../sparse/_aten_function_overrides/shape.py | 95 ++++++++----------- 1 file changed, 42 insertions(+), 53 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 0aa38c948..4023199af 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -1,3 +1,5 @@ +import operator +from itertools import accumulate from math import prod from typing import cast @@ -16,6 +18,27 @@ @impl(aten.view.default) def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: + """ + The main condition that we want to respect is that the indexing in the flattened virtual + tensor should remain the same before and after the reshape, i.e. + + c.T S = c'.T S' (1) + where: + * c is the reversed vector of cumulative physical shape before the reshape, i.e. + c.T = [prod(t.shape[1:]), prod(t.shape[2:]), ..., t.shape[-1], 1] + * c' is the same thing but after the reshape, i.e. + c'.T = [prod(shape[1:]), prod(shape[2:]), ..., shape[-1], 1] + * S is the original matrix of strides (t.strides) + * S' is the matrix of strides after reshaping. + + For u, v in Z^m and c in Z, say that u ≡ v (mod c) if u_i ≡ v_i (mod c) for all i. + Note that c'.T S' ≡ S'[-1] (mod shape[-1]) + So if we set S'[-1] = c.T S % shape[-1], we have c.T S ≡ c'.T S' (mod shape[-1]) + + (c'.T S' - S'[-1]) // shape[-1] ≡ S'[-1] (mod shape[-1]) + ... + """ + assert isinstance(t, StructuredSparseTensor) shape = infer_shape(shape, t.numel()) @@ -23,59 +46,25 @@ def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: if prod(shape) != t.numel(): raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") - new_v_to_ps = [] - idx = 0 - flat_v_to_ps = [dim for dims in t.v_to_ps for dim in dims] - new_physical = t.physical - for s in shape: - group = [] - current_size = 1 - - while current_size < s: - if idx >= len(flat_v_to_ps): - # TODO: I don't think this can happen, need to review and remove if I'm right. 
- raise ValueError() - - pdim = flat_v_to_ps[idx] - pdim_size = new_physical.shape[pdim] - - if current_size * pdim_size > s: - # Need to split physical dimension - if s % current_size != 0: - raise ValueError("Can't split physical dimension") - - new_pdim_first_dim_size = s // current_size - - if pdim_size % new_pdim_first_dim_size != 0: - raise ValueError("Can't split physical dimension") - - new_pdim_shape = [new_pdim_first_dim_size, pdim_size // new_pdim_first_dim_size] - new_physical, new_encoding = unsquash_pdim(new_physical, pdim, new_pdim_shape) - - new_v_to_ps = [ - [new_d for d in dims for new_d in new_encoding[d]] for dims in new_v_to_ps - ] - # A bit of a weird trick here. We want to re-encode flat_v_to_ps according to - # new_encoding. However, re-encoding elements before idx would potentially change - # the length of the list before idx, so idx would not have the right value anymore. - # Since we don't need the elements of flat_v_to_ps that are before idx anyway, we - # just get rid of them and re-encode flat_v_to_ps[idx:] instead, and reset idx to 0 - # to say that we're back at the beginning of this new list. - flat_v_to_ps = [new_d for d in flat_v_to_ps[idx:] for new_d in new_encoding[d]] - idx = 0 - - group.append(pdim) - current_size *= new_physical.shape[pdim] - idx += 1 - - new_v_to_ps.append(group) - - if idx != len(flat_v_to_ps): - raise ValueError(f"idx != len(flat_v_to_ps). {idx}; {flat_v_to_ps}; {shape}; {t.v_to_ps}") - - # The above code does not handle physical dimension squashing, so the physical is not - # necessarily maximally squashed at this point, so we need the safe constructor. - return to_most_efficient_tensor(new_physical, new_v_to_ps) + S = t.strides + vshape = list(t.shape) + c = _reverse_cumulative_product(vshape) + remaining_cT_S = c @ S + + stride_rows = list[Tensor]() + for modulo in shape[::-1]: + stride_row = remaining_cT_S % modulo + stride_rows.append(stride_row) + remaining_cT_S = (remaining_cT_S - stride_row) // modulo + # I think we could skip the - stride_row because the floor div will handle it for us, but it + # will make code harder to understand. + + new_strides = torch.stack(stride_rows, dim=0) + return to_most_efficient_tensor(t.physical, new_strides) + + +def _reverse_cumulative_product(values: list[int]) -> Tensor: + return tensor(list(accumulate((values[1:] + [1])[::-1], operator.mul))[::-1]) def infer_shape(shape: list[int], numel: int) -> list[int]: From fc3339c0dd8dad1ca80f3163123bdcd12da5a02e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 7 Nov 2025 18:07:32 +0100 Subject: [PATCH 176/182] Fix order of rows of new_strides in view_default --- src/torchjd/sparse/_aten_function_overrides/shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 4023199af..20f5a0542 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -59,7 +59,7 @@ def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: # I think we could skip the - stride_row because the floor div will handle it for us, but it # will make code harder to understand. 
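# Worked example of the stride-update rule described in the docstring above (illustrative, plain
# torch only). Take the "strided diagonal" from test_to_dense2: physical a of shape [3],
# S = [[4]], virtual shape [9] (values at flat indices 0, 4 and 8). Viewing it as [3, 3]:
#   c = [1]            (reverse cumulative products of the old virtual shape [9])
#   c.T S = [4]
#   last row of S':    4 % 3 = 1,  remaining (4 - 1) // 3 = 1
#   first row of S':   1 % 3 = 1
# so S' = [[1], [1]], i.e. the plain diagonal matrix, which is indeed what the dense view gives:
import torch

a = torch.tensor([1.0, 2.0, 3.0])
flat = torch.zeros(9)
flat[::4] = a  # dense view of the SST with physical a and S = [[4]]
assert torch.equal(flat.view(3, 3), torch.diag(a))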
- new_strides = torch.stack(stride_rows, dim=0) + new_strides = torch.stack(stride_rows[::-1], dim=0) return to_most_efficient_tensor(t.physical, new_strides) From fac9c72fae53a5253b15a73834eb1c809869faa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 7 Nov 2025 18:19:52 +0100 Subject: [PATCH 177/182] Fix creation of SST in autogram --- src/torchjd/autogram/_engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index dafe362ec..1db1b6d47 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -175,8 +175,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - v_to_ps = [[dim] for dim in output_dims * 2] - jac_output = make_sst(torch.ones_like(output), v_to_ps) + identity = torch.eye(output.ndim, dtype=torch.int64) + strides = torch.concatenate([identity, identity.clone()], dim=0) + jac_output = make_sst(torch.ones_like(output), strides) vmapped_diff = differentiation for _ in output_dims: From 66c221026cd77098c1fbbb7925d30f604a71a498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 7 Nov 2025 18:55:01 +0100 Subject: [PATCH 178/182] One-line view_default. The trick is that remaining_cT_S did not really depend on stride_row, because the floor_division would give the same result if stride_row was not removed. So remaining_cT_S can be precomputed by pre-dividing. Similarly, the modulo can be done in parallel. So there's no need for any for-loop anymore, all can be computed at once. --- .../sparse/_aten_function_overrides/shape.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 20f5a0542..d10715150 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -49,17 +49,8 @@ def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: S = t.strides vshape = list(t.shape) c = _reverse_cumulative_product(vshape) - remaining_cT_S = c @ S - - stride_rows = list[Tensor]() - for modulo in shape[::-1]: - stride_row = remaining_cT_S % modulo - stride_rows.append(stride_row) - remaining_cT_S = (remaining_cT_S - stride_row) // modulo - # I think we could skip the - stride_row because the floor div will handle it for us, but it - # will make code harder to understand. 
- - new_strides = torch.stack(stride_rows[::-1], dim=0) + c_prime = _reverse_cumulative_product(shape) + new_strides = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1) return to_most_efficient_tensor(t.physical, new_strides) From 5cbab0802fa01d6dfb3c117b60c09df225571dbe Mon Sep 17 00:00:00 2001 From: Matthieu Buot de l'Epine Date: Wed, 12 Nov 2025 18:03:04 +0100 Subject: [PATCH 179/182] update interface of unsquash_pdim --- .../sparse/_aten_function_overrides/shape.py | 69 +++++++++++++------ .../sparse/test_structured_sparse_tensor.py | 23 +++++-- 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index d10715150..a4c255607 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -69,34 +69,59 @@ def infer_shape(shape: list[int], numel: int) -> list[int]: return [inferred if s == -1 else s for s in shape] -def unsquash_pdim_from_strides( - physical: Tensor, pdim: int, new_pdim_shape: list[int] +def unsquash_pdim( + physical: Tensor, strides: Tensor, pdim: int, new_pdim_shape: list[int] ) -> tuple[Tensor, Tensor]: - new_shape = list(physical.shape) - new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :] - new_physical = physical.reshape(new_shape) - - stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) - return new_physical, stride_multipliers + """ + EXAMPLE: + + physical = [ + [1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + ] + strides = [ + [1, 1], + [0, 2], + ] + + dim = 1 + shape = [2, 3] + + new_physical = [[ + [1, 2, 3], + [4, 5, 6], + ], [ + [7, 8, 9], + [10, 11, 12], + ], [ + [13, 14, 15], + [16, 17, 18], + ]] + + new_strides = [ + [1, 3, 1], + [0, 6, 2] + """ + # TODO: handle working with multiple dimensions at once -def unsquash_pdim( - physical: Tensor, pdim: int, new_pdim_shape: list[int] -) -> tuple[Tensor, list[list[int]]]: - new_shape = list(physical.shape) - new_shape = new_shape[:pdim] + new_pdim_shape + new_shape[pdim + 1 :] + old_shape = list(physical.shape) + new_shape = old_shape[:pdim] + new_pdim_shape + old_shape[pdim + 1 :] new_physical = physical.reshape(new_shape) - def new_encoding_fn(d: int) -> list[int]: - if d < pdim: - return [d] - elif d > pdim: - return [d + len(new_pdim_shape) - 1] - else: - return [pdim + i for i in range(len(new_pdim_shape))] + stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) + + new_strides = torch.concat( + [ + strides[:, :pdim], + torch.outer(strides[:, pdim], stride_multipliers), + strides[:, pdim + 1 :], + ], + dim=1, + ) - new_encoding = [new_encoding_fn(d) for d in range(len(physical.shape))] - return new_physical, new_encoding + return new_physical, new_strides @impl(aten._unsafe_view.default) diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py index 9bff7f98c..414410269 100644 --- a/tests/unit/sparse/test_structured_sparse_tensor.py +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -315,29 +315,38 @@ def test_fix_ungrouped_dims( @mark.parametrize( [ "physical_shape", + "strides", "pdim", "new_pdim_shape", "expected_physical_shape", - "expected_new_encoding", + "expected_strides", ], [ - ([4], 0, [4], [4], [[0]]), # trivial - ([4], 0, [2, 2], [2, 2], [[0, 1]]), - ([3, 4, 5], 1, [2, 1, 1, 2], [3, 2, 1, 1, 2, 
5], [[0], [1, 2, 3, 4], [5]]), + ([4], tensor([[1], [2]]), 0, [4], [4], tensor([[1], [2]])), # trivial + ([4], tensor([[1], [2]]), 0, [2, 2], [2, 2], tensor([[2, 1], [4, 2]])), + ( + [3, 4, 5], + tensor([[1, 2, 0], [1, 0, 1], [0, 1, 1]]), + 1, + [2, 1, 1, 2], + [3, 2, 1, 1, 2, 5], + tensor([[1, 4, 4, 4, 2, 0], [1, 0, 0, 0, 0, 1], [0, 2, 2, 2, 1, 1]]), + ), ], ) def test_unsquash_pdim( physical_shape: list[int], + strides: Tensor, pdim: int, new_pdim_shape: list[int], expected_physical_shape: list[int], - expected_new_encoding: list[list[int]], + expected_strides: Tensor, ): physical = randn_(physical_shape) - new_physical, new_encoding = unsquash_pdim(physical, pdim, new_pdim_shape) + new_physical, new_strides = unsquash_pdim(physical, strides, pdim, new_pdim_shape) assert list(new_physical.shape) == expected_physical_shape - assert new_encoding == expected_new_encoding + assert torch.equal(new_strides, expected_strides) @mark.parametrize( From a5c8cbfc8c9add4e3ac2619cb376b81ffab934b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 12 Nov 2025 19:52:37 +0100 Subject: [PATCH 180/182] Add get_full_source --- .../sparse/_structured_sparse_tensor.py | 37 +++++++++++++++++++ .../sparse/test_structured_sparse_tensor.py | 21 +++++++++++ 2 files changed, 58 insertions(+) diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py index 528a8eb4b..2ba9b132a 100644 --- a/src/torchjd/sparse/_structured_sparse_tensor.py +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -243,6 +243,43 @@ def unwrap_to_dense(t: Tensor): return t +def get_full_source(source: list[int], destination: list[int], ndim: int) -> list[int]: + """ + Doing a movedim with source and destination is always equivalent to doing a movedim with + [0, 1, ..., ndim-1] (aka "full_destination") as destination, and the "full_source" as source. + + This function computes the full_source based on a source and destination. 
+ + Example: + source=[2, 4] + destination=[0, 3] + ndim=5 + + full_source = [2, 0, 1, 4, 3] + full_destination = [0, 1, 2, 3, 4] + """ + + idx = torch.full((ndim,), -1, dtype=torch.int64) + idx[destination] = tensor(source) + source_set = set(source) + idx[idx.eq(-1)] = tensor([i for i in range(ndim) if i not in source_set]) + + # source_mask = torch.zeros(ndim, dtype=torch.bool) + # destination_mask = torch.zeros(ndim, dtype=torch.bool) + # source_mask[source] = True + # destination_mask[destination] = True + # + # destination_cumsum = torch.cumsum(destination_mask, dim=0) + # source_cumsum = torch.cumsum(source_mask, dim=0) + # base = arange(ndim, dtype=torch.int64) + # + # idx = torch.empty((ndim,), dtype=torch.int64) + # idx[destination_mask] = tensor(source) + # idx[~destination_mask] = base[~destination_mask] - destination_cumsum[~destination_mask] + source_cumsum[:ndim - len(source)] + + return idx.tolist() + + def fix_dim_of_size_1(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: is_of_size_1 = torch.tensor([s == 1 for s in physical.shape]) return physical.squeeze(), strides[:, ~is_of_size_1] diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py index 414410269..00f6112da 100644 --- a/tests/unit/sparse/test_structured_sparse_tensor.py +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -15,6 +15,7 @@ StructuredSparseTensor, fix_ungrouped_dims, fix_zero_stride_columns, + get_full_source, get_groupings, ) @@ -349,6 +350,26 @@ def test_unsquash_pdim( assert torch.equal(new_strides, expected_strides) +@mark.parametrize( + [ + "source", + "destination", + "ndim", + ], + [ + ([2, 4], [0, 3], 5), + ([5, 3, 6], [2, 0, 5], 8), + ], +) +def test_get_column_indices(source: list[int], destination: list[int], ndim: int): + # TODO: this test should be improved / removed. It creates quite big tensors for nothing. 
+ + t = randn_(list(torch.randint(3, 8, size=(ndim,)))) + full_destination = list(range(ndim)) + full_source = get_full_source(source, destination, ndim) + assert torch.equal(t.movedim(full_source, full_destination), t.movedim(source, destination)) + + @mark.parametrize( ["sst_args", "dim"], [ From c1fc11516fab4dddb1076eefd42f9b77796909da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 12 Nov 2025 19:52:53 +0100 Subject: [PATCH 181/182] Remove alternative implementation --- src/torchjd/sparse/_structured_sparse_tensor.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py index 2ba9b132a..11ad01b2a 100644 --- a/src/torchjd/sparse/_structured_sparse_tensor.py +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -264,19 +264,6 @@ def get_full_source(source: list[int], destination: list[int], ndim: int) -> lis source_set = set(source) idx[idx.eq(-1)] = tensor([i for i in range(ndim) if i not in source_set]) - # source_mask = torch.zeros(ndim, dtype=torch.bool) - # destination_mask = torch.zeros(ndim, dtype=torch.bool) - # source_mask[source] = True - # destination_mask[destination] = True - # - # destination_cumsum = torch.cumsum(destination_mask, dim=0) - # source_cumsum = torch.cumsum(source_mask, dim=0) - # base = arange(ndim, dtype=torch.int64) - # - # idx = torch.empty((ndim,), dtype=torch.int64) - # idx[destination_mask] = tensor(source) - # idx[~destination_mask] = base[~destination_mask] - destination_cumsum[~destination_mask] + source_cumsum[:ndim - len(source)] - return idx.tolist() From bb8faaf0090365c842087af8573598e0a1cdda7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 17 Nov 2025 16:57:37 +0100 Subject: [PATCH 182/182] Remove useless clone --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 1db1b6d47..964b94a67 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -176,7 +176,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: output_dims = list(range(output.ndim)) identity = torch.eye(output.ndim, dtype=torch.int64) - strides = torch.concatenate([identity, identity.clone()], dim=0) + strides = torch.concatenate([identity, identity], dim=0) jac_output = make_sst(torch.ones_like(output), strides) vmapped_diff = differentiation
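
The docstring examples introduced in patches 179 and 180 can be sanity-checked with a short standalone script. The sketch below is illustrative only and is not part of the patches: the helper name full_source_permutation is made up for the example (it mirrors the behaviour described in the get_full_source docstring rather than importing torchjd), and only public torch APIs are used.

import torch


def full_source_permutation(source: list[int], destination: list[int], ndim: int) -> list[int]:
    # Hypothetical re-implementation of the permutation described in the
    # get_full_source docstring: place each source dim at its requested
    # destination slot, then fill the remaining slots with the untouched
    # dims in their original order.
    perm = [-1] * ndim
    for s, d in zip(source, destination):
        perm[d] = s
    used = set(source)
    remaining = iter(i for i in range(ndim) if i not in used)
    return [p if p != -1 else next(remaining) for p in perm]


source, destination, ndim = [2, 4], [0, 3], 5
perm = full_source_permutation(source, destination, ndim)
assert perm == [2, 0, 1, 4, 3]  # matches the docstring example

# movedim with the full source and the identity destination gives the same
# result as movedim with the partial source/destination pair.
t = torch.randn(2, 3, 4, 5, 6)
assert torch.equal(t.movedim(perm, list(range(ndim))), t.movedim(source, destination))

# Stride update behind the new unsquash_pdim interface: unsquashing the
# physical dim of size 6 into [2, 3] multiplies that stride column by the
# suffix products [3, 1] and leaves the other columns unchanged.
strides = torch.tensor([[1, 1], [0, 2]])
multipliers = torch.tensor([3, 1])
new_strides = torch.cat([strides[:, :1], torch.outer(strides[:, 1], multipliers)], dim=1)
assert torch.equal(new_strides, torch.tensor([[1, 3, 1], [0, 6, 2]]))  # matches the docstring example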