
Add numba dispatch for expm #2148

Merged
ricardoV94 merged 4 commits into main from numba-expm
May 20, 2026

Conversation

@jessegrabowski
Member

`expm` comes up in continuous-time state-space models and ODEs.
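As a minimal illustration of that use case (not code from this PR), the sketch below uses `scipy.linalg.expm` to build the exact transition matrix of a linear ODE `x' = A x` and steps a harmonic oscillator forward. The system matrix, step size, and variable names are assumptions chosen for the example.

```python
import numpy as np
from scipy.linalg import expm

# Harmonic oscillator x'' = -x written as a first-order system x' = A x.
A = np.array([[0.0, 1.0], [-1.0, 0.0]])
dt = 0.1

# Exact discrete-time transition matrix for one step of length dt.
Phi = expm(A * dt)

# Advance the state ten steps, i.e. to t = 1.0.
x = np.array([1.0, 0.0])
for _ in range(10):
    x = Phi @ x

# Closed-form solution for this initial condition: x(t) = [cos t, -sin t].
expected = np.array([np.cos(1.0), -np.sin(1.0)])
print(np.allclose(x, expected))
```

In a continuous-time state-space model the same transition matrix is recomputed whenever `A` or `dt` changes, which is why the per-call cost at small `n` matters.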

Performance (ratio columns are relative to scipy; values below 1× are faster than scipy):

  ┌─────┬─────────┬─────────┬──────────┬─────────────┬───────────┐
  │  n  │  scipy  │  numba  │   jax    │ numba/scipy │ jax/scipy │
  ├─────┼─────────┼─────────┼──────────┼─────────────┼───────────┤
  │ 4   │ 3.9 us  │ 2.7 us  │ 17.0 us  │ 0.70×       │ 4.31×     │
  ├─────┼─────────┼─────────┼──────────┼─────────────┼───────────┤
  │ 16  │ 24.4 us │ 8.0 us  │ 21.9 us  │ 0.33×       │ 0.90×     │
  ├─────┼─────────┼─────────┼──────────┼─────────────┼───────────┤
  │ 32  │ 46.6 us │ 18.5 us │ 47.4 us  │ 0.40×       │ 1.02×     │
  ├─────┼─────────┼─────────┼──────────┼─────────────┼───────────┤
  │ 64  │ 91.4 us │ 53.5 us │ 162 us   │ 0.59×       │ 1.77×     │
  ├─────┼─────────┼─────────┼──────────┼─────────────┼───────────┤
  │ 128 │ 337 us  │ 283 us  │ 575 us   │ 0.84×       │ 1.71×     │
  ├─────┼─────────┼─────────┼──────────┼─────────────┼───────────┤
  │ 256 │ 1892 us │ 1729 us │ 2243 us  │ 0.91×       │ 1.19×     │
  ├─────┼─────────┼─────────┼──────────┼─────────────┼───────────┤
  │ 512 │ 9079 us │ 9028 us │ 11754 us │ 0.99×       │ 1.29×     │
  └─────┴─────────┴─────────┴──────────┴─────────────┴───────────┘

@jessegrabowski requested a review from ricardoV94 May 19, 2026 19:02
@jessegrabowski added the enhancement (New feature or request), numba, performance, and linalg (Linear algebra) labels May 19, 2026
@jessegrabowski
Member Author

Shit, PyCharm made the branch on upstream again, sorry about that.

Comment thread pytensor/link/numba/dispatch/linalg/products.py Outdated
Comment thread pytensor/link/numba/dispatch/linalg/products.py Outdated
@jessegrabowski
Member Author

> how faster?

Microbench:

from timeit import Timer

import numba
import numpy as np
from numba.np.linalg import _copy_to_fortran_order


@numba.njit
def stdlib(A):
    return _copy_to_fortran_order(A)


@numba.njit
def manual(A):
    n, m = A.shape
    out = np.empty((m, n), dtype=A.dtype).T  # F-contiguous (n, m) view of a fresh buffer
    for j in range(m):
        for i in range(n):
            out[i, j] = A[i, j]
    return out


def best_us(fn, A, repeat=5, number=200):
    t = Timer(lambda: fn(A))
    fn(A)  # warm
    return min(t.repeat(repeat=repeat, number=number)) / number * 1e6


def main():
    rng = np.random.default_rng(0)
    layouts = {
        "C-contig": lambda A: A,
        "F-contig": np.asfortranarray,
        "strided ": lambda A: np.repeat(A, 2, axis=0)[::2],
    }

    print(f"{'n':>5}  {'layout':<9} {'stdlib':>11} {'manual':>10} {'speedup':>8}")
    print("-" * 50)
    for n in (4, 16, 64, 256, 1024):
        A = rng.normal(size=(n, n))
        for label, prep in layouts.items():
            V = prep(A)
            assert np.array_equal(stdlib(V), manual(V))
            t_std = best_us(stdlib, V)
            t_man = best_us(manual, V)
            print(
                f"{n:>5}  {label:<9} {t_std:>8.2f} us {t_man:>7.2f} us"
                f"  {t_std / t_man:>6.2f}x"
            )
        print()


if __name__ == "__main__":
    main()

Results:

    n  layout         stdlib     manual  speedup
--------------------------------------------------
    4  C-contig      0.28 us    0.29 us    0.97x
    4  F-contig      0.29 us    0.27 us    1.04x
    4  strided       0.28 us    0.30 us    0.93x

   16  C-contig      0.36 us    0.35 us    1.03x
   16  F-contig      0.32 us    0.29 us    1.10x
   16  strided       0.34 us    0.32 us    1.07x

   64  C-contig      1.31 us    1.35 us    0.97x
   64  F-contig      0.65 us    0.61 us    1.06x
   64  strided       1.24 us    1.17 us    1.07x

  256  C-contig    135.44 us   36.67 us    3.69x
  256  F-contig      6.01 us    6.67 us    0.90x
  256  strided     130.27 us   50.00 us    2.61x

 1024  C-contig   3426.08 us  935.88 us    3.66x
 1024  F-contig    106.63 us  117.28 us    0.91x
 1024  strided    3587.76 us 1535.61 us    2.34x

For large matrices (n ≥ 256) where a copy is actually needed (C-contiguous or strided input), it matters: the manual loop is roughly 2.3x to 3.7x faster. Otherwise it's a push.
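To make "where a copy is actually needed" concrete, here is a small NumPy-only sketch that checks which of the three benchmark layouts force a copy when converted to Fortran order. The helper name `to_fortran` is hypothetical, and it uses `np.asfortranarray` rather than either of the numba routines timed above.

```python
import numpy as np


def to_fortran(A):
    """Return an F-contiguous array with A's contents, copying only if needed."""
    # np.asfortranarray hands back the input when it is already F-contiguous.
    return np.asfortranarray(A)


rng = np.random.default_rng(0)
C = rng.normal(size=(4, 4))        # NumPy default layout: C-contiguous
F = np.asfortranarray(C)           # F-contiguous copy of the same values
S = np.repeat(C, 2, axis=0)[::2]   # strided view: neither C- nor F-contiguous

for name, arr in {"C-contig": C, "F-contig": F, "strided": S}.items():
    out = to_fortran(arr)
    copied = not np.shares_memory(out, arr)
    print(name, arr.flags.c_contiguous, arr.flags.f_contiguous, copied)
```

Only the C-contiguous and strided inputs report a copy, which matches the rows of the results table where the two implementations diverge.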

@jessegrabowski
Member Author

jessegrabowski commented May 19, 2026

I was trying to wring every possible us out to get the numba version to be as fast as scipy. They have a nice fused C kernel for doing this Padé scaling stuff, so it was non-trivial to beat. Happy to revert for now and revisit.

In fact, I'd rather revert for now; it's a weird, out-of-scope thing for this PR.
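For readers unfamiliar with the Padé scaling approach mentioned above, the following is a textbook scaling-and-squaring sketch with a fixed-order diagonal Padé approximant. It is neither scipy's kernel nor the implementation in this PR; the function name `expm_pade` and the fixed order `q=6` are assumptions made for illustration.

```python
import numpy as np


def expm_pade(A, q=6):
    """Matrix exponential via scaling and squaring with a [q/q] Pade approximant."""
    A = np.asarray(A, dtype=float)
    n = A.shape[0]
    norm = np.linalg.norm(A, np.inf)
    # Scale A by 2**-j so the scaled matrix has infinity norm below 1.
    j = max(0, 1 + int(np.floor(np.log2(norm)))) if norm > 0 else 0
    A = A / 2.0**j

    # Accumulate numerator N(A) and denominator D(A) = N(-A) of the approximant.
    X = np.eye(n)
    N = np.eye(n)
    D = np.eye(n)
    c = 1.0
    for k in range(1, q + 1):
        c = c * (q - k + 1) / ((2 * q - k + 1) * k)
        X = A @ X
        N = N + c * X
        D = D + (-1) ** k * c * X

    # exp(A / 2**j) is approximated by D^{-1} N; undo the scaling by squaring j times.
    F = np.linalg.solve(D, N)
    for _ in range(j):
        F = F @ F
    return F


# Sanity check on a rotation generator, whose exponential is known in closed form.
t = 3.0
R = expm_pade(np.array([[0.0, t], [-t, 0.0]]))
expected = np.array([[np.cos(t), np.sin(t)], [-np.sin(t), np.cos(t)]])
print(np.allclose(R, expected))
```

The cost is dominated by the `q` matrix products, one linear solve, and `j` squarings, so at small `n` the overhead of each separate call is what a fused kernel avoids.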

@ricardoV94
Member

Your strided test is only fair if you stride both something that started as C and something that started as F, since the optimal iteration order depends on that. That said, I don't know what the worst-case scenario is, since you are constrained by the output layout.
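A quick way to see the point about the origin of a strided view is to slice the same values out of a C-ordered and an F-ordered parent and compare strides. This is a NumPy-only sketch; the names `base_c`, `base_f`, `view_c`, and `view_f` are made up for the example and do not come from the benchmark above.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
base_c = rng.normal(size=(2 * n, 2 * n))  # C-ordered parent
base_f = np.asfortranarray(base_c)        # F-ordered parent, same values

# Take every other row and column: both views are non-contiguous,
# but each inherits the axis ordering of its parent.
view_c = base_c[::2, ::2]
view_f = base_f[::2, ::2]

print("from C parent:", view_c.strides)  # smaller stride on the last axis
print("from F parent:", view_f.strides)  # smaller stride on the first axis
print(np.array_equal(view_c, view_f))
```

A copy loop that walks the axis with the smaller stride in its inner loop reads the input more locally, while the F-contiguous output always favours the first axis, so the two strided cases can pull in opposite directions.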

@jessegrabowski requested a review from ricardoV94 May 19, 2026 20:58
@ricardoV94 merged commit 5582499 into main May 20, 2026
66 checks passed
@ricardoV94 deleted the numba-expm branch May 20, 2026 15:20

2 participants