ENH: Implement inv(A) rewrite for triangular matrices (GSoC 2026) #1927
Open
Youssef-Naggar wants to merge 1 commit into pymc-devs:main from
Title: ENH: Implement `inv(A)` rewrite for triangular matrices (GSoC 2026)
What is this PR about?
Hi everyone! I am a CS student highly interested in joining the GSoC 2026 program to work on the "Linear algebra rewrites" project with PyTensor. To familiarize myself with the PyTensor codebase, the graph rewriting system, and the testing framework, I decided to tackle one of the optimizations mentioned in the COLA library issue tracker.
This PR implements a graph rewrite that replaces the standard `MatrixInverse` operation with the highly optimized `SolveTriangular` operation whenever a matrix is tagged as lower or upper triangular. I would love any feedback or notes on my implementation to help strengthen my GSoC proposal!
Major / Breaking Changes
None.
New features
- Added `rewrite_inv_triangular_to_solve` in `pytensor/tensor/rewriting/linalg.py`.
- The rewrite checks for `lower_triangular` or `upper_triangular` tags and constructs an identity matrix of matching shape and dtype to route the computation through `solve_triangular`.
Bugfixes
None.
Documentation
- `rewrite_inv_triangular_to_solve` rewrite.
Maintenance
- Added `test_triangular_inv_rewrite_and_grad` to `tests/tensor/rewriting/test_linalg.py` to ensure that the graph is rewritten successfully and that gradients (`verify_grad`) are computed correctly.
Results of run `tests/tensor/rewriting/test_linalg.py`
```
(.venv) PS C:\self-study\gsoc\pytensor> pytest C:\self-study\gsoc\pytensor\tests\tensor\rewriting\test_linalg.py
======================================================================================================= test session starts =======================================================================================================
platform win32 -- Python 3.14.0, pytest-9.0.2, pluggy-1.6.0
benchmark: 5.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: C:\self-study\gsoc\pytensor
configfile: pyproject.toml
plugins: benchmark-5.2.3, cov-7.0.0, mock-3.15.1, sphinx-0.7.1
collected 133 items

tests\tensor\rewriting\test_linalg.py ..................................................................................................................................... [100%]

======================================================================================================== warnings summary =========================================================================================================
tests/tensor/rewriting/test_linalg.py::test_nested_blockdiag_fusion
C:\self-study\gsoc\pytensor\pytensor\graph\rewriting\basic.py:110: UserWarning: A Supervisor feature is missing from FunctionGraph(BlockDiagonal{n_inputs=3}(x, y, z)).
This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.
A Supervisor feature can be added via `pytensor.compile.function.types.add_supervisor_to_fgraph`.
    return self.apply(fgraph, *args, **kwargs)
tests/tensor/rewriting/test_linalg.py::test_deeply_nested_blockdiag_fusion
C:\self-study\gsoc\pytensor\pytensor\graph\rewriting\basic.py:110: UserWarning: A Supervisor feature is missing from FunctionGraph(BlockDiagonal{n_inputs=4}(x, y, z, w)).
This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.
A Supervisor feature can be added via `pytensor.compile.function.types.add_supervisor_to_fgraph`.
    return self.apply(fgraph, *args, **kwargs)
tests/tensor/rewriting/test_linalg.py::test_matrix_inverse_rop_lop
C:\self-study\gsoc\pytensor\pytensor\link\c\cmodule.py:2986: UserWarning: PyTensor could not link to a BLAS installation. Operations that might benefit from BLAS will be severely degraded.
This usually happens when PyTensor is installed via pip. We recommend it be installed via conda/mamba/pixi instead.
Alternatively, you can use an experimental backend such as Numba or JAX that perform their own BLAS optimizations, by setting `pytensor.config.mode == 'NUMBA'` or passing `mode='NUMBA'` when compiling a PyTensor function.
For more options and details see https://pytensor.readthedocs.io/en/latest/troubleshooting.html#how-do-i-configure-test-my-blas-library
warnings.warn(
tests/tensor/rewriting/test_linalg.py::test_matrix_inverse_rop_lop
C:\self-study\gsoc\pytensor\tests\tensor\rewriting\test_linalg.py:124: DeprecationWarning: Scan return signature will change. Updates dict will not be returned, only the first argument. Pass `return_updates=False` to conform to the new API and avoid this warning
    sy, _ = pytensor.scan(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================================================================== slowest 50 durations =======================================================================================================
22.05s call tests/tensor/rewriting/test_linalg.py::test_matrix_inverse_rop_lop
14.57s call tests/tensor/rewriting/test_linalg.py::test_det_blockdiag_rewrite
5.05s call tests/tensor/rewriting/test_linalg.py::TestBatchedVectorBSolveToMatrixBSolve::test_valid_cases[solve]
4.49s call tests/tensor/rewriting/test_linalg.py::test_triangular_inv_rewrite_and_grad
4.03s call tests/tensor/rewriting/test_linalg.py::test_det_kronecker_rewrite
3.84s call tests/tensor/rewriting/test_linalg.py::test_slogdet_blockdiag_rewrite
3.30s call tests/tensor/rewriting/test_linalg.py::test_slogdet_specialization
3.19s call tests/tensor/rewriting/test_linalg.py::test_local_det_chol
3.19s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-inv-not_batched]
3.17s call tests/tensor/rewriting/test_linalg.py::test_diag_kronecker_rewrite
2.91s call tests/tensor/rewriting/test_linalg.py::test_cholesky_diag_from_eye_mul[vector]
2.90s call tests/tensor/rewriting/test_linalg.py::test_cholesky_diag_from_eye_mul[batched]
2.85s call tests/tensor/rewriting/test_linalg.py::test_cholesky_diag_from_eye_mul[scalar]
2.60s call tests/tensor/rewriting/test_linalg.py::test_diag_blockdiag_rewrite
2.49s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_eye_mul[batched_input]
2.34s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-inv-batched]
2.25s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_diag[inv]
1.71s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=2-b_batch_shape=(5,)-a_batch_shape=()]
1.68s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_eye_mul[inv-batched]
1.62s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=1-b_batch_shape=()-a_batch_shape=(5,)]
1.61s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_eye_mul[inv-scalar]
1.60s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=2-b_batch_shape=()-a_batch_shape=(5,)]
1.57s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[SolveTriangular-solve_triangular-extra_kwargs2-b_ndim=2-b_batch_shape=()-a_batch_shape=(5,)]
1.54s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=1-b_batch_shape=(5,)-a_batch_shape=(5,)]
1.51s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_eye_mul[inv-vector]
1.50s call tests/tensor/rewriting/test_linalg.py::test_rewrite_cholesky_diag_to_sqrt_diag_not_applied
1.49s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=2-b_batch_shape=(5,)-a_batch_shape=(5,)]
1.47s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=1-b_batch_shape=()-a_batch_shape=(5,)]
1.44s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=2-b_batch_shape=()-a_batch_shape=(5,)]
1.43s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=1-b_batch_shape=(5,)-a_batch_shape=()]
1.42s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=2-b_batch_shape=(5,)-a_batch_shape=()]
1.39s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[SolveTriangular-solve_triangular-extra_kwargs2-b_ndim=1-b_batch_shape=()-a_batch_shape=(5,)]
1.33s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=1-b_batch_shape=()-a_batch_shape=()]
1.28s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_eye_mul[scalar]
1.24s call tests/tensor/rewriting/test_linalg.py::test_cholesky_diag_from_diag
1.23s call tests/tensor/rewriting/test_linalg.py::TestBatchedVectorBSolveToMatrixBSolve::test_invalid_batched_a
1.23s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_eye_mul[vector]
1.19s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=1-b_batch_shape=()-a_batch_shape=()]
0.96s call tests/tensor/rewriting/test_linalg.py::test_cholesky_ldotlt[dot-lower-lower-lower]
0.50s call tests/tensor/rewriting/test_linalg.py::test_slogdet_kronecker_rewrite
0.37s call tests/tensor/rewriting/test_linalg.py::test_nested_blockdiag_fusion
0.23s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-pinv-not_batched]
0.22s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-cholesky-batched]
0.19s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-pinv-batched]
0.18s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-cholesky-not_batched]
0.17s call tests/tensor/rewriting/test_linalg.py::test_svd_uv_merge
0.12s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_diag
0.11s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_eye_mul[matrix]
0.09s call tests/tensor/rewriting/test_linalg.py::TestBatchedVectorBSolveToMatrixBSolve::test_valid_cases[cho_solve]
0.08s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_diag[pinv]
=========================================================================================== 133 passed, 4 warnings in 140.64s (0:02:20) ===========================================================================================
```
Performance & Benchmarks
I created a local benchmark (`benchmark_inv.py`) to test the speedup on a 1000x1000 lower triangular matrix. Bypassing the standard `inv` operation and using the rewrite yielded a speedup of about 11x on average:
My benchmark code (benchmark_inv.py)
import timeit

import numpy as np
import pytensor
import pytensor.tensor as pt


def run_benchmark():
    print("Setting up benchmark...")
    N = 1000  # 1000x1000 matrix
    np.random.seed(42)
    # ...


if __name__ == "__main__":
    run_benchmark()
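Because the listing above shows only part of the benchmark body, here is a rough, self-contained sketch of the same kind of comparison. It is not the author's script: it times NumPy's `np.linalg.inv` against SciPy's `scipy.linalg.solve_triangular` directly rather than compiled PyTensor functions, and the helper names `make_lower_triangular` and `benchmark` are hypothetical. Timings will differ by machine and BLAS build, so they are not expected to match the figures reported in this PR.

```python
import timeit

import numpy as np
from scipy.linalg import solve_triangular


def make_lower_triangular(n: int, seed: int = 42) -> np.ndarray:
    """Build a random lower triangular matrix (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Assumption of this sketch: adding n to the diagonal keeps the
    # matrix well conditioned, so both methods return stable results.
    return np.tril(rng.normal(size=(n, n))) + n * np.eye(n)


def benchmark(n: int = 1000, runs: int = 10) -> tuple[float, float]:
    """Return total seconds for `runs` general and triangular inversions."""
    A = make_lower_triangular(n)
    eye = np.eye(n)
    # General inverse: LU factorization followed by a solve.
    t_inv = timeit.timeit(lambda: np.linalg.inv(A), number=runs)
    # Triangular path: solve A @ X = I, exploiting the known structure.
    t_tri = timeit.timeit(
        lambda: solve_triangular(A, eye, lower=True), number=runs
    )
    return t_inv, t_tri


if __name__ == "__main__":
    t_inv, t_tri = benchmark()
    print(f"np.linalg.inv    ({10} runs): {t_inv:.4f} s")
    print(f"solve_triangular ({10} runs): {t_tri:.4f} s")
```

Using the same matrix and identity for both timings keeps the comparison limited to the solver itself; a PyTensor benchmark would also include function-call overhead.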
the output of my benchmark test
Benchmarking Matrix Inverse for 1000x1000 Triangular Matrix...
Standard `inv` Time (10 runs): 1.1401 seconds
Optimized `solve_triangular` Time (10 runs): 0.0995 seconds
Speedup: 11.45x faster!
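For readers less familiar with why this works, the rewrite rests on the identity that the inverse of `A` is the solution `X` of `A @ X = I`. When `A` is triangular, that system can be solved by substitution with no factorization step. The sketch below illustrates this in plain NumPy for the lower triangular case; `inv_lower_triangular` is a hypothetical helper written for illustration, not code from this PR or from PyTensor.

```python
import numpy as np


def inv_lower_triangular(L: np.ndarray) -> np.ndarray:
    """Invert a lower triangular matrix by solving L @ X = I
    with forward substitution (illustrative, not optimized)."""
    n = L.shape[0]
    X = np.zeros_like(L, dtype=float)
    eye = np.eye(n)
    for i in range(n):
        # Row i of X depends only on rows 0..i-1, already computed:
        # sum_{k<=i} L[i, k] * X[k, :] = eye[i, :]
        X[i, :] = (eye[i, :] - L[i, :i] @ X[:i, :]) / L[i, i]
    return X


rng = np.random.default_rng(0)
n = 5
# Hypothetical test matrix: random lower triangle with a dominant diagonal.
L = np.tril(rng.normal(size=(n, n))) + n * np.eye(n)

X = inv_lower_triangular(L)

# The result agrees with the general-purpose inverse...
assert np.allclose(X, np.linalg.inv(L))
# ...and the inverse of a lower triangular matrix is itself lower triangular.
assert np.allclose(X, np.tril(X))
```

The upper triangular case is symmetric: iterate over the rows in reverse and use back substitution.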