From d4e1098f475220d7ff3ea127d57b3eddfc721731 Mon Sep 17 00:00:00 2001 From: Alessandro Gentili Date: Sat, 28 Feb 2026 09:59:44 +0100 Subject: [PATCH] extending inv_as_solve to batched/Blockwise operations --- pytensor/tensor/rewriting/linalg.py | 37 +++++++++++++-------------- tests/tensor/rewriting/test_linalg.py | 29 +++++++++++++++++++++ 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 2c17020cd9..5bd5fa4c0e 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -132,29 +132,28 @@ def transinv_to_invtrans(fgraph, node): @register_stabilize -@node_rewriter([Dot]) +@node_rewriter([Dot, _matmul]) def inv_as_solve(fgraph, node): """ This utilizes a boolean `symmetric` tag on the matrices. """ - if isinstance(node.op, Dot): - l, r = node.inputs - if ( - l.owner - and isinstance(l.owner.op, Blockwise) - and isinstance(l.owner.op.core_op, MatrixInverse) - ): - return [solve(l.owner.inputs[0], r)] - if ( - r.owner - and isinstance(r.owner.op, Blockwise) - and isinstance(r.owner.op.core_op, MatrixInverse) - ): - x = r.owner.inputs[0] - if getattr(x.tag, "symmetric", None) is True: - return [solve(x, (l.mT)).mT] - else: - return [solve((x.mT), (l.mT)).mT] + l, r = node.inputs + if ( + l.owner + and isinstance(l.owner.op, Blockwise) + and isinstance(l.owner.op.core_op, MatrixInverse) + ): + return [solve(l.owner.inputs[0], r)] + if ( + r.owner + and isinstance(r.owner.op, Blockwise) + and isinstance(r.owner.op.core_op, MatrixInverse) + ): + x = r.owner.inputs[0] + if getattr(x.tag, "symmetric", None) is True: + return [solve(x, (l.mT)).mT] + else: + return [solve((x.mT), (l.mT)).mT] @register_stabilize diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 9e8783e51a..c1f89ca3a1 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -219,6 +219,35 @@ def test_matrix_inverse_solve(): ) +@pytest.mark.parametrize( + "batch_shape", + [(), (5,), (3, 5)], + ids=["no_batch", "single_batch", "multi_batch"], +) +def test_batched_matrix_inverse_solve(batch_shape): + """Test that inv_as_solve fires for batched (matmul) operations.""" + n = 4 + shape = (*batch_shape, n, n) + A = pt.tensor("A", shape=shape, dtype="float64") + b = pt.tensor("b", shape=shape, dtype="float64") + + # inv(A) @ b should be rewritten to solve(A, b) + out = matmul(pt.linalg.inv(A), b) + f = pytensor.function([A, b], out, mode="FAST_RUN") + + # Graph check: should contain Solve, not MatrixInverse + nodes = f.maker.fgraph.apply_nodes + has_solve = any( + isinstance(getattr(node.op, "core_op", node.op), Solve) for node in nodes + ) + has_inv = any( + isinstance(getattr(node.op, "core_op", node.op), MatrixInverse) + for node in nodes + ) + assert has_solve, "Expected Solve in the rewritten graph" + assert not has_inv, "MatrixInverse should have been rewritten away" + + @pytest.mark.parametrize("tag", ("lower", "upper", None)) @pytest.mark.parametrize("cholesky_form", ("lower", "upper")) @pytest.mark.parametrize("product", ("lower", "upper", None))