Skip to content

ENH: Implement inv(A) rewrite for triangular matrices (GSoC 2026)#1927

Open
Youssef-Naggar wants to merge 1 commit intopymc-devs:mainfrom
Youssef-Naggar:rewrite-inv-triangular
Open

ENH: Implement inv(A) rewrite for triangular matrices (GSoC 2026)#1927
Youssef-Naggar wants to merge 1 commit intopymc-devs:mainfrom
Youssef-Naggar:rewrite-inv-triangular

Conversation

@Youssef-Naggar
Copy link

Title: ENH: Implement inv(A) rewrite for triangular matrices (GSoC 2026)

What is this PR about?

Hi everyone! I am a CS student highly interested in joining the GSoC 2026 program to work on the "Linear algebra rewrites" project with PyTensor. To familiarize myself with the PyTensor codebase, the graph rewriting system, and the testing framework, I decided to tackle one of the optimizations mentioned in the COLA library issue tracker.

This PR implements a graph rewrite that replaces the standard MatrixInverse operation with the highly optimized SolveTriangular operation whenever a matrix is tagged as lower or upper triangular. I would love any feedback or notes on my implementation to help strengthen my GSoC proposal!

Major / Breaking Changes

None.

New features

  • Implemented rewrite_inv_triangular_to_solve in pytensor/tensor/rewriting/linalg.py.
  • The rewrite dynamically checks for lower_triangular or upper_triangular tags and constructs an identity matrix of matching shape and dtype to route the computation through solve_triangular.

Bugfixes

None.

Documentation

  • Added docstrings explaining the mathematical justification and logic for the rewrite_inv_triangular_to_solve rewrite.

Maintenance

  • Added test_triangular_inv_rewrite_and_grad to tests/tensor/rewriting/test_linalg.py to ensure the graph successfully rewrites and that gradients (verify_grad) are computed correctly.
``` Results of run `tests/tensor/rewriting/test_linalg.py` ``` (.venv) PS C:\self-study\gsoc\pytensor> pytest C:\self-study\gsoc\pytensor\tests\tensor\rewriting\test_linalg.py ======================================================================================================= test session starts ======================================================================================================= platform win32 -- Python 3.14.0, pytest-9.0.2, pluggy-1.6.0 benchmark: 5.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000) rootdir: C:\self-study\gsoc\pytensor configfile: pyproject.toml plugins: benchmark-5.2.3, cov-7.0.0, mock-3.15.1, sphinx-0.7.1 collected 133 items

tests\tensor\rewriting\test_linalg.py ..................................................................................................................................... [100%]

======================================================================================================== warnings summary =========================================================================================================
tests/tensor/rewriting/test_linalg.py::test_nested_blockdiag_fusion
C:\self-study\gsoc\pytensor\pytensor\graph\rewriting\basic.py:110: UserWarning: A Supervisor feature is missing from FunctionGraph(BlockDiagonal{n_inputs=3}(x, y, z)).
This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.
A Supervisor feature can be added via pytensor.compile.function.types.add_supervisor_to_fgraph.
return self.apply(fgraph, *args, **kwargs)

tests/tensor/rewriting/test_linalg.py::test_deeply_nested_blockdiag_fusion
C:\self-study\gsoc\pytensor\pytensor\graph\rewriting\basic.py:110: UserWarning: A Supervisor feature is missing from FunctionGraph(BlockDiagonal{n_inputs=4}(x, y, z, w)).
This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.
A Supervisor feature can be added via pytensor.compile.function.types.add_supervisor_to_fgraph.
return self.apply(fgraph, *args, **kwargs)

tests/tensor/rewriting/test_linalg.py::test_matrix_inverse_rop_lop
C:\self-study\gsoc\pytensor\pytensor\link\c\cmodule.py:2986: UserWarning: PyTensor could not link to a BLAS installation. Operations that might benefit from BLAS will be severely degraded.
This usually happens when PyTensor is installed via pip. We recommend it be installed via conda/mamba/pixi instead.
Alternatively, you can use an experimental backend such as Numba or JAX that perform their own BLAS optimizations, by setting pytensor.config.mode == 'NUMBA' or passing mode='NUMBA' when compiling a PyTensor function.
For more options and details see https://pytensor.readthedocs.io/en/latest/troubleshooting.html#how-do-i-configure-test-my-blas-library
warnings.warn(

tests/tensor/rewriting/test_linalg.py::test_matrix_inverse_rop_lop
C:\self-study\gsoc\pytensor\tests\tensor\rewriting\test_linalg.py:124: DeprecationWarning: Scan return signature will change. Updates dict will not be returned, only the first argument. Pass return_updates=False to conform to the new API and avoid this warning
sy, _ = pytensor.scan(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================================================================== slowest 50 durations =======================================================================================================
22.05s call tests/tensor/rewriting/test_linalg.py::test_matrix_inverse_rop_lop
14.57s call tests/tensor/rewriting/test_linalg.py::test_det_blockdiag_rewrite
5.05s call tests/tensor/rewriting/test_linalg.py::TestBatchedVectorBSolveToMatrixBSolve::test_valid_cases[solve]
4.49s call tests/tensor/rewriting/test_linalg.py::test_triangular_inv_rewrite_and_grad
4.03s call tests/tensor/rewriting/test_linalg.py::test_det_kronecker_rewrite
3.84s call tests/tensor/rewriting/test_linalg.py::test_slogdet_blockdiag_rewrite
3.30s call tests/tensor/rewriting/test_linalg.py::test_slogdet_specialization
3.19s call tests/tensor/rewriting/test_linalg.py::test_local_det_chol
3.19s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-inv-not_batched]
3.17s call tests/tensor/rewriting/test_linalg.py::test_diag_kronecker_rewrite
2.91s call tests/tensor/rewriting/test_linalg.py::test_cholesky_diag_from_eye_mul[vector]
2.90s call tests/tensor/rewriting/test_linalg.py::test_cholesky_diag_from_eye_mul[batched]
2.85s call tests/tensor/rewriting/test_linalg.py::test_cholesky_diag_from_eye_mul[scalar]
2.60s call tests/tensor/rewriting/test_linalg.py::test_diag_blockdiag_rewrite
2.49s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_eye_mul[batched_input]
2.34s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-inv-batched]
2.25s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_diag[inv]
1.71s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=2-b_batch_shape=(5,)-a_batch_shape=()]
1.68s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_eye_mul[inv-batched]
1.62s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=1-b_batch_shape=()-a_batch_shape=(5,)]
1.61s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_eye_mul[inv-scalar]
1.60s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=2-b_batch_shape=()-a_batch_shape=(5,)]
1.57s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[SolveTriangular-solve_triangular-extra_kwargs2-b_ndim=2-b_batch_shape=()-a_batch_shape=(5,)]
1.54s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=1-b_batch_shape=(5,)-a_batch_shape=(5,)]
1.51s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_eye_mul[inv-vector]
1.50s call tests/tensor/rewriting/test_linalg.py::test_rewrite_cholesky_diag_to_sqrt_diag_not_applied
1.49s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=2-b_batch_shape=(5,)-a_batch_shape=(5,)]
1.47s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=1-b_batch_shape=()-a_batch_shape=(5,)]
1.44s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=2-b_batch_shape=()-a_batch_shape=(5,)]
1.43s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=1-b_batch_shape=(5,)-a_batch_shape=()]
1.42s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=2-b_batch_shape=(5,)-a_batch_shape=()]
1.39s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[SolveTriangular-solve_triangular-extra_kwargs2-b_ndim=1-b_batch_shape=()-a_batch_shape=(5,)]
1.33s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[Solve-solve-extra_kwargs0-b_ndim=1-b_batch_shape=()-a_batch_shape=()]
1.28s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_eye_mul[scalar]
1.24s call tests/tensor/rewriting/test_linalg.py::test_cholesky_diag_from_diag
1.23s call tests/tensor/rewriting/test_linalg.py::TestBatchedVectorBSolveToMatrixBSolve::test_invalid_batched_a
1.23s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_eye_mul[vector]
1.19s call tests/tensor/rewriting/test_linalg.py::test_scalar_solve_to_division_rewrite[CholeskySolve-cho_solve-extra_kwargs3-b_ndim=1-b_batch_shape=()-a_batch_shape=()]
0.96s call tests/tensor/rewriting/test_linalg.py::test_cholesky_ldotlt[dot-lower-lower-lower]
0.50s call tests/tensor/rewriting/test_linalg.py::test_slogdet_kronecker_rewrite
0.37s call tests/tensor/rewriting/test_linalg.py::test_nested_blockdiag_fusion
0.23s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-pinv-not_batched]
0.22s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-cholesky-batched]
0.19s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-pinv-batched]
0.18s call tests/tensor/rewriting/test_linalg.py::test_local_lift_through_linalg[kron-cholesky-not_batched]
0.17s call tests/tensor/rewriting/test_linalg.py::test_svd_uv_merge
0.12s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_diag
0.11s call tests/tensor/rewriting/test_linalg.py::test_det_diag_from_eye_mul[matrix]
0.09s call tests/tensor/rewriting/test_linalg.py::TestBatchedVectorBSolveToMatrixBSolve::test_valid_cases[cho_solve]
0.08s call tests/tensor/rewriting/test_linalg.py::test_inv_diag_from_diag[pinv]
=========================================================================================== 133 passed, 4 warnings in 140.64s (0:02:20) ===========================================================================================

Performance & Benchmarks

I created a local benchmark (benchmark_inv.py) to test the speedup on a 1000x1000 lower triangular matrix. Bypassing the standard inv operation and utilizing the rewrite yielded a range of 11x speedup (on average):

my benchmark code (benchmark_inv.py) import timeit import numpy as np import pytensor import pytensor.tensor as pt

def run_benchmark():
print("Setting up benchmark...")
N = 1000 # 1000x1000 matrix
np.random.seed(42)

# Create a large lower triangular matrix  
# Add scalar to ensure diagonal is safely away from 0    A_val = np.tril(np.random.rand(N, N) + 10.0)  
  
# 1. Standard Inverse (Without Rewrite)  
A_std = pt.dmatrix('A_std')  
Z_std = pt.linalg.inv(A_std)  
  
# Compile standard inverse mapping, intentionally ignoring our custom rewrite  
mode_std = pytensor.compile.mode.get_default_mode().excluding("rewrite_inv_triangular_to_solve")  
f_std = pytensor.function([A_std], Z_std, mode=mode_std)  
  
# 2. Optimized Inverse (With Rewrite)  
A_opt = pt.dmatrix('A_opt')  
A_opt.tag.lower_triangular = True  
Z_opt = pt.linalg.inv(A_opt)  
  
# Compile optimized inverse mapping, explicitly including our custom rewrite  
mode_opt = pytensor.compile.mode.get_default_mode().including("rewrite_inv_triangular_to_solve")  
f_opt = pytensor.function([A_opt], Z_opt, mode=mode_opt)  
  
# Verify the rewrite actually applied  
opt_nodes = f_opt.maker.fgraph.toposort()  
from pytensor.tensor.slinalg import SolveTriangular  
is_rewritten = any(isinstance(getattr(node.op, "core_op", node.op), SolveTriangular) for node in opt_nodes)  
print(f"Rewrite applied successfully: {is_rewritten}")  
if not is_rewritten:  
    print("WARNING: The rewrite did NOT apply in the benchmark script!")  
  
print(f"\nBenchmarking Matrix Inverse for {N}x{N} Triangular Matrix...")  
  
# Warmup  
_ = f_std(A_val)  
_ = f_opt(A_val)  

# Benchmark Standard  
std_time = timeit.timeit(lambda: f_std(A_val), number=10)  
print(f"Standard `inv` Time (10 runs): {std_time:.4f} seconds")  
  
# Benchmark Rewrite  
opt_time = timeit.timeit(lambda: f_opt(A_val), number=10)  
print(f"Optimized `solve_triangular` Time (10 runs): {opt_time:.4f} seconds")  
  
print(f"\nSpeedup: {std_time / opt_time:.2f}x faster!")  

if name == "main":
run_benchmark()

the output of my benchmark test Benchmarking Matrix Inverse for 1000x1000 Triangular Matrix... Standard `inv` Time (10 runs): 1.1401 seconds Optimized `solve_triangular` Time (10 runs): 0.0995 seconds

Speedup: 11.45x faster!

@Youssef-Naggar Youssef-Naggar force-pushed the rewrite-inv-triangular branch from bd50d6b to f4989bf Compare March 2, 2026 19:58
@Youssef-Naggar Youssef-Naggar changed the base branch from v3 to main March 2, 2026 20:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant