From 81e0f280036f13cb7a8306ee190a432dbdee6636 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 19 May 2026 13:21:16 -0500 Subject: [PATCH 1/4] Allow inplace on Expm --- pytensor/tensor/linalg/products.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 471ad51ce7..061be2ddf5 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -15,9 +15,14 @@ class Expm(Op): Compute the matrix exponential of a square array. """ - __props__ = () + __props__ = ("overwrite_a",) gufunc_signature = "(m,m)->(m,m)" + def __init__(self, overwrite_a: bool = False): + self.overwrite_a = overwrite_a + if self.overwrite_a: + self.destroy_map = {0: [0]} + def make_node(self, A): A = as_tensor_variable(A) assert A.ndim == 2 @@ -31,6 +36,13 @@ def perform(self, node, inputs, outputs): (expm,) = outputs expm[0] = scipy_linalg.expm(A) + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": + if not allowed_inplace_inputs: + return self + new_props = self._props_dict() # type: ignore + new_props["overwrite_a"] = True + return type(self)(**new_props) + def pullback(self, inputs, outputs, output_grads): from pytensor.tensor.linalg.solvers.general import solve From ba6521aa5f96feb022646b3cdede71cc5cafdabd Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 19 May 2026 13:59:24 -0500 Subject: [PATCH 2/4] Add numba dispatch for expm --- pytensor/link/numba/dispatch/basic.py | 2 +- .../link/numba/dispatch/linalg/__init__.py | 1 + .../link/numba/dispatch/linalg/products.py | 341 ++++++++++++++++++ tests/link/numba/linalg/test_products.py | 77 ++++ 4 files changed, 420 insertions(+), 1 deletion(-) create mode 100644 pytensor/link/numba/dispatch/linalg/products.py create mode 100644 tests/link/numba/linalg/test_products.py diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 
6a01532380..c0abd7bebe 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -48,7 +48,7 @@ def _filter_numba_warnings(): "ignore", message=( "(\x1b\\[1m)*" # ansi escape code for bold text - r"np\.dot\(\) is faster on contiguous arrays" + r"(np\.dot\(\)|np\.vdot\(\)|'@') is faster on contiguous arrays" ), category=NumbaPerformanceWarning, ) diff --git a/pytensor/link/numba/dispatch/linalg/__init__.py b/pytensor/link/numba/dispatch/linalg/__init__.py index 4b0f38bc40..86612a2536 100644 --- a/pytensor/link/numba/dispatch/linalg/__init__.py +++ b/pytensor/link/numba/dispatch/linalg/__init__.py @@ -1,5 +1,6 @@ import pytensor.link.numba.dispatch.linalg.constructors import pytensor.link.numba.dispatch.linalg.decomposition.dispatch import pytensor.link.numba.dispatch.linalg.inverse +import pytensor.link.numba.dispatch.linalg.products import pytensor.link.numba.dispatch.linalg.solvers.dispatch import pytensor.link.numba.dispatch.linalg.summary diff --git a/pytensor/link/numba/dispatch/linalg/products.py b/pytensor/link/numba/dispatch/linalg/products.py new file mode 100644 index 0000000000..d0b551becd --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/products.py @@ -0,0 +1,341 @@ +import numpy as np +from numba.core.extending import overload +from numba.core.types import Complex, Float +from numba.np.linalg import ensure_lapack +from scipy import linalg + +from pytensor import config +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix +from pytensor.tensor.linalg.products import Expm + + +@numba_basic.numba_njit(inline="always") +def _copy_to_f(A): + # Faster than numba's stdlib _copy_to_fortran_order at moderate sizes. 
+ n, m = A.shape + out = np.empty((n, m), dtype=A.dtype).T + for j in range(m): + for i in range(n): + out[i, j] = A[i, j] + return out + + +@numba_basic.numba_njit(inline="always") +def _poly2_id(c0, A0, c1, A1, id_c, out): + n = out.shape[0] + for i in range(n): + for j in range(n): + out[i, j] = c0 * A0[i, j] + c1 * A1[i, j] + for i in range(n): + out[i, i] += id_c + + +@numba_basic.numba_njit(inline="always") +def _poly3(c0, A0, c1, A1, c2, A2, out): + n = out.shape[0] + for i in range(n): + for j in range(n): + out[i, j] = c0 * A0[i, j] + c1 * A1[i, j] + c2 * A2[i, j] + + +@numba_basic.numba_njit(inline="always") +def _poly3_id(c0, A0, c1, A1, c2, A2, id_c, out): + n = out.shape[0] + for i in range(n): + for j in range(n): + out[i, j] = c0 * A0[i, j] + c1 * A1[i, j] + c2 * A2[i, j] + for i in range(n): + out[i, i] += id_c + + +@numba_basic.numba_njit(inline="always") +def _poly4_id(c0, A0, c1, A1, c2, A2, c3, A3, id_c, out): + n = out.shape[0] + for i in range(n): + for j in range(n): + out[i, j] = c0 * A0[i, j] + c1 * A1[i, j] + c2 * A2[i, j] + c3 * A3[i, j] + for i in range(n): + out[i, i] += id_c + + +def _expm(A, overwrite_a=False): + return linalg.expm(A) + + +@overload(_expm) +def _expm_impl(A, overwrite_a): + # Al-Mohy & Higham 2009 Pade scaling-and-squaring (Tables 2.3, 3.1). 
+ ensure_lapack() + _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="expm") + + real_dtype = _get_underlying_float(A.dtype) + is_single = real_dtype == np.float32 + + numba_xgetrf = _LAPACK().numba_xgetrf(A.dtype) + numba_xgetrs = _LAPACK().numba_xgetrs(A.dtype) + + if is_single: + theta_max = real_dtype.type(3.925724783138660) + theta_3 = real_dtype.type(4.258730016922831e-01) + theta_5 = real_dtype.type(1.880152677804762e00) + theta_7 = real_dtype.type(3.925724783138660) + theta_9 = real_dtype.type(3.925724783138660) + else: + theta_max = real_dtype.type(5.371920351148152) + theta_3 = real_dtype.type(1.495585217958292e-02) + theta_5 = real_dtype.type(2.539398330063230e-01) + theta_7 = real_dtype.type(9.504178996162932e-01) + theta_9 = real_dtype.type(2.097847961257068e00) + + b3 = tuple(real_dtype.type(x) for x in (120.0, 60.0, 12.0, 1.0)) + b5 = tuple(real_dtype.type(x) for x in (30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0)) + b7 = tuple( + real_dtype.type(x) + for x in ( + 17297280.0, + 8648640.0, + 1995840.0, + 277200.0, + 25200.0, + 1512.0, + 56.0, + 1.0, + ) + ) + b9 = tuple( + real_dtype.type(x) + for x in ( + 17643225600.0, + 8821612800.0, + 2075673600.0, + 302702400.0, + 30270240.0, + 2162160.0, + 110880.0, + 3960.0, + 90.0, + 1.0, + ) + ) + b13 = tuple( + real_dtype.type(x) + for x in ( + 64764752532480000.0, + 32382376266240000.0, + 7771770303897600.0, + 1187353796428800.0, + 129060195264000.0, + 10559470521600.0, + 670442572800.0, + 33522128640.0, + 1323241920.0, + 40840800.0, + 960960.0, + 16380.0, + 182.0, + 1.0, + ) + ) + + def impl(A, overwrite_a): + n = A.shape[-1] + + A_L1 = np.linalg.norm(A, 1) + + if A_L1 > theta_max: + s = int(np.ceil(np.log2(A_L1 / theta_max))) + else: + s = 0 + + # expm(X.T) = expm(X).T -- run the kernel on A.T when A is c-contig so + # we get an f-contig view of the input buffer for free. 
+ transposed = False + if A.flags.c_contiguous: + A_s = A.T if overwrite_a else A.copy().T + transposed = True + elif overwrite_a and A.flags.f_contiguous: + A_s = A + else: + A_s = _copy_to_f(A) + + A_s = np.asfortranarray(A_s) + + if s > 0: + A_s /= real_dtype.type(2.0) ** s + + norm_scaled = A_L1 / (real_dtype.type(2.0) ** s) + + dtype = A_s.dtype + A2 = np.empty((n, n), dtype=dtype) + np.dot(A_s, A_s, A2) + U = np.empty((n, n), dtype=dtype) + V = np.empty((n, n), dtype=dtype) + S = np.empty((n, n), dtype=dtype) + T = np.empty((n, n), dtype=dtype).T # f-contig, consumed by getrs + + if is_single: + if norm_scaled <= theta_3: + # U = A_s @ (b3[3]*A2 + b3[1]*I); V = b3[2]*A2 + b3[0]*I + np.multiply(b3[3], A2, S) + for i in range(n): + S[i, i] += b3[1] + np.dot(A_s, S, U) + np.multiply(b3[2], A2, V) + for i in range(n): + V[i, i] += b3[0] + elif norm_scaled <= theta_5: + # U = A_s @ (b5[5]*A4 + b5[3]*A2 + b5[1]*I) + # V = b5[4]*A4 + b5[2]*A2 + b5[0]*I + A4 = np.empty((n, n), dtype=dtype) + np.dot(A2, A2, A4) + _poly2_id(b5[5], A4, b5[3], A2, b5[1], S) + np.dot(A_s, S, U) + _poly2_id(b5[4], A4, b5[2], A2, b5[0], V) + else: + # U = A_s @ (b7[7]*A6 + b7[5]*A4 + b7[3]*A2 + b7[1]*I) + # V = b7[6]*A6 + b7[4]*A4 + b7[2]*A2 + b7[0]*I + A4 = np.empty((n, n), dtype=dtype) + np.dot(A2, A2, A4) + A6 = np.empty((n, n), dtype=dtype) + np.dot(A4, A2, A6) + _poly3_id(b7[7], A6, b7[5], A4, b7[3], A2, b7[1], S) + np.dot(A_s, S, U) + _poly3_id(b7[6], A6, b7[4], A4, b7[2], A2, b7[0], V) + else: + if norm_scaled <= theta_3: + # U = A_s @ (b3[3]*A2 + b3[1]*I); V = b3[2]*A2 + b3[0]*I + np.multiply(b3[3], A2, S) + for i in range(n): + S[i, i] += b3[1] + np.dot(A_s, S, U) + np.multiply(b3[2], A2, V) + for i in range(n): + V[i, i] += b3[0] + elif norm_scaled <= theta_5: + # U = A_s @ (b5[5]*A4 + b5[3]*A2 + b5[1]*I) + # V = b5[4]*A4 + b5[2]*A2 + b5[0]*I + A4 = np.empty((n, n), dtype=dtype) + np.dot(A2, A2, A4) + _poly2_id(b5[5], A4, b5[3], A2, b5[1], S) + np.dot(A_s, S, U) + _poly2_id(b5[4], 
A4, b5[2], A2, b5[0], V) + elif norm_scaled <= theta_7: + # U = A_s @ (b7[7]*A6 + b7[5]*A4 + b7[3]*A2 + b7[1]*I) + # V = b7[6]*A6 + b7[4]*A4 + b7[2]*A2 + b7[0]*I + A4 = np.empty((n, n), dtype=dtype) + np.dot(A2, A2, A4) + A6 = np.empty((n, n), dtype=dtype) + np.dot(A4, A2, A6) + _poly3_id(b7[7], A6, b7[5], A4, b7[3], A2, b7[1], S) + np.dot(A_s, S, U) + _poly3_id(b7[6], A6, b7[4], A4, b7[2], A2, b7[0], V) + elif norm_scaled <= theta_9: + # U = A_s @ (b9[9]*A8 + b9[7]*A6 + b9[5]*A4 + b9[3]*A2 + b9[1]*I) + # V = b9[8]*A8 + b9[6]*A6 + b9[4]*A4 + b9[2]*A2 + b9[0]*I + A4 = np.empty((n, n), dtype=dtype) + np.dot(A2, A2, A4) + A6 = np.empty((n, n), dtype=dtype) + np.dot(A4, A2, A6) + A8 = np.empty((n, n), dtype=dtype) + np.dot(A6, A2, A8) + _poly4_id(b9[9], A8, b9[7], A6, b9[5], A4, b9[3], A2, b9[1], S) + np.dot(A_s, S, U) + _poly4_id(b9[8], A8, b9[6], A6, b9[4], A4, b9[2], A2, b9[0], V) + else: + # Pade 13 via Horner (Higham 2005 eqs. 2.2-2.3), so we never + # form A^8/A^10/A^12 explicitly. + # W1 = b13[13]*A6 + b13[11]*A4 + b13[9]*A2 + # W2 = b13[7]*A6 + b13[5]*A4 + b13[3]*A2 + b13[1]*I + # U = A_s @ (A6 @ W1 + W2) + # Z1 = b13[12]*A6 + b13[10]*A4 + b13[8]*A2 + # Z2 = b13[6]*A6 + b13[4]*A4 + b13[2]*A2 + b13[0]*I + # V = A6 @ Z1 + Z2 + A4 = np.empty((n, n), dtype=dtype) + np.dot(A2, A2, A4) + A6 = np.empty((n, n), dtype=dtype) + np.dot(A4, A2, A6) + _poly3(b13[13], A6, b13[11], A4, b13[9], A2, S) # S = W1 + _poly3_id(b13[7], A6, b13[5], A4, b13[3], A2, b13[1], U) # U = W2 + np.dot(A6, S, V) # V = A6 @ W1 + V += U # V = A6 @ W1 + W2 + np.dot(A_s, V, U) # U = A_s @ V (final U) + _poly3(b13[12], A6, b13[10], A4, b13[8], A2, S) # S = Z1 + np.dot(A6, S, V) # V = A6 @ Z1 + # V += Z2 fused with the np.dot output + for i in range(n): + for j in range(n): + V[i, j] += ( + b13[6] * A6[i, j] + b13[4] * A4[i, j] + b13[2] * A2[i, j] + ) + for i in range(n): + V[i, i] += b13[0] + + np.add(U, V, T) # T = P = U + V + V -= U # V = Q = V - U + + # Solve Q R = P -> V is c-contig; pass V.T 
as A and undo with TRANS='T'. + n_i32 = np.int32(n) + N_PTR = val_to_int_ptr(n_i32) + LDA = val_to_int_ptr(n_i32) + LDB = val_to_int_ptr(n_i32) + NRHS = val_to_int_ptr(n_i32) + TRANS = val_to_int_ptr(np.int32(ord("T"))) + INFO_RF = val_to_int_ptr(np.int32(0)) + INFO_RS = val_to_int_ptr(np.int32(0)) + IPIV = np.empty(n, dtype=np.int32) + V_T = V.T + + numba_xgetrf(N_PTR, N_PTR, V_T.ctypes, LDA, IPIV.ctypes, INFO_RF) + numba_xgetrs( + TRANS, N_PTR, NRHS, V_T.ctypes, LDA, IPIV.ctypes, T.ctypes, LDB, INFO_RS + ) + + R = T + if int_ptr_to_val(INFO_RF) != 0 or int_ptr_to_val(INFO_RS) != 0: + R[:] = np.nan + + if s > 0: + A2[:] = R + R = A2 + R_buf = U + for _ in range(s): + np.dot(R, R, R_buf) + R, R_buf = R_buf, R + + if transposed: + return R.T + return R + + return impl + + +@register_funcify_default_op_cache_key(Expm) +def numba_funcify_Expm(op, node, **kwargs): + overwrite_a = op.overwrite_a + + inp_dtype = node.inputs[0].type.numpy_dtype + discrete_input = inp_dtype.kind in "ibu" + if discrete_input and config.compiler_verbose: + print("Expm requires casting discrete input to float") # noqa: T201 + + out_dtype = node.outputs[0].type.numpy_dtype + effective_overwrite_a = overwrite_a and not discrete_input + + @numba_basic.numba_njit + def expm(a): + if a.size == 0: + return np.zeros(a.shape, dtype=out_dtype) + if discrete_input: + a = a.astype(out_dtype) + return _expm(a, effective_overwrite_a) + + cache_version = 1 + return expm, cache_version diff --git a/tests/link/numba/linalg/test_products.py b/tests/link/numba/linalg/test_products.py new file mode 100644 index 0000000000..9ff8b9ca43 --- /dev/null +++ b/tests/link/numba/linalg/test_products.py @@ -0,0 +1,77 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import In, config +from pytensor.tensor.linalg.products import Expm, expm +from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode + + +pytestmark = [ + pytest.mark.filterwarnings("error"), + 
pytest.mark.filterwarnings("ignore::numba.core.errors.NumbaPerformanceWarning"), +] + +numba = pytest.importorskip("numba") + +floatX = config.floatX + +rng = np.random.default_rng(42849) + + +class TestExpm: + @pytest.mark.parametrize("dtype", ["float32", "float64", "complex64", "complex128"]) + @pytest.mark.parametrize( + "overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"] + ) + def test_expm(self, overwrite_a: bool, dtype: str): + A = pt.matrix("A", dtype=dtype) + y = Expm(overwrite_a=overwrite_a)(A) + + x = rng.normal(size=(4, 4)) * 5.0 + if np.dtype(dtype).kind == "c": + x = x + 1j * rng.normal(size=(4, 4)) * 5.0 + val = x.astype(dtype) + rtol = 1e-3 if np.dtype(dtype).char in "fF" else 1e-10 + + def assert_fn(actual, expected): + np.testing.assert_allclose(actual, expected, rtol=rtol) + + fn, res = compare_numba_and_py( + [In(A, mutable=overwrite_a)], + [y], + [val], + numba_mode=numba_inplace_mode, + inplace=True, + assert_fn=assert_fn, + ) + + op = fn.maker.fgraph.outputs[0].owner.op + assert isinstance(op, Expm) + assert overwrite_a == (op.destroy_map == {0: [0]}) + + # F-contiguous input is mutated when overwrite_a=True (kernel uses + # A's buffer directly as scratch during scaling). + val_f_contig = np.copy(val, order="F") + res_f_contig = fn(val_f_contig) + np.testing.assert_allclose(res_f_contig, res, rtol=rtol) + assert (val == val_f_contig).all() == (not overwrite_a) + + # C-contiguous input is also mutated when overwrite_a=True: the kernel + # takes A.T (f-contig view of A's buffer) and computes expm(A.T) = + # expm(A).T, scaling A's buffer in place along the way. + val_c_contig = np.copy(val, order="C") + res_c_contig = fn(val_c_contig) + np.testing.assert_allclose(res_c_contig, res, rtol=rtol) + assert (val == val_c_contig).all() == (not overwrite_a) + + # Non-contiguous (strided) input is never mutated: the kernel copies it first. 
+ val_not_contig = np.repeat(val, 2, axis=0)[::2] + res_not_contig = fn(val_not_contig) + np.testing.assert_allclose(res_not_contig, res, rtol=rtol) + np.testing.assert_allclose(val_not_contig, val) + + def test_expm_size_zero(self): + A = pt.matrix("A", dtype=floatX) + y = expm(A) + compare_numba_and_py([A], [y], [np.zeros((0, 0), dtype=floatX)]) From 7efb9ac3e7947f2ccb9a95d9f5dfd60414f81526 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 19 May 2026 15:11:03 -0500 Subject: [PATCH 3/4] Remove `_copy_to_f`, use `numba.np.linalg._copy_to_fortran_order` --- pytensor/link/numba/dispatch/linalg/products.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/pytensor/link/numba/dispatch/linalg/products.py b/pytensor/link/numba/dispatch/linalg/products.py index d0b551becd..49390c34b2 100644 --- a/pytensor/link/numba/dispatch/linalg/products.py +++ b/pytensor/link/numba/dispatch/linalg/products.py @@ -1,7 +1,7 @@ import numpy as np from numba.core.extending import overload from numba.core.types import Complex, Float -from numba.np.linalg import ensure_lapack +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from scipy import linalg from pytensor import config @@ -17,17 +17,6 @@ from pytensor.tensor.linalg.products import Expm -@numba_basic.numba_njit(inline="always") -def _copy_to_f(A): - # Faster than numba's stdlib _copy_to_fortran_order at moderate sizes. 
- n, m = A.shape - out = np.empty((n, m), dtype=A.dtype).T - for j in range(m): - for i in range(n): - out[i, j] = A[i, j] - return out - - @numba_basic.numba_njit(inline="always") def _poly2_id(c0, A0, c1, A1, id_c, out): n = out.shape[0] @@ -164,7 +153,7 @@ def impl(A, overwrite_a): elif overwrite_a and A.flags.f_contiguous: A_s = A else: - A_s = _copy_to_f(A) + A_s = _copy_to_fortran_order(A) A_s = np.asfortranarray(A_s) From 25a7304f61c229734a6dc2bb3fd213990fb0dea0 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 19 May 2026 15:19:40 -0500 Subject: [PATCH 4/4] Fix integer inputs --- pytensor/link/numba/dispatch/linalg/products.py | 2 +- pytensor/tensor/linalg/products.py | 4 +++- tests/link/numba/linalg/test_products.py | 11 +++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pytensor/link/numba/dispatch/linalg/products.py b/pytensor/link/numba/dispatch/linalg/products.py index 49390c34b2..7c2561d6fb 100644 --- a/pytensor/link/numba/dispatch/linalg/products.py +++ b/pytensor/link/numba/dispatch/linalg/products.py @@ -316,7 +316,7 @@ def numba_funcify_Expm(op, node, **kwargs): print("Expm requires casting discrete input to float") # noqa: T201 out_dtype = node.outputs[0].type.numpy_dtype - effective_overwrite_a = overwrite_a and not discrete_input + effective_overwrite_a = overwrite_a or discrete_input @numba_basic.numba_njit def expm(a): diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 061be2ddf5..a5a33a1dcc 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -6,6 +6,7 @@ from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.linalg._lazy import scipy_linalg +from pytensor.tensor.linalg.dtype_utils import linalg_output_dtype from pytensor.tensor.symbolic import TensorSymbolicOp from pytensor.tensor.type import matrix @@ -27,7 +28,8 @@ def make_node(self, A): A = as_tensor_variable(A) 
assert A.ndim == 2 - expm = matrix(dtype=A.dtype, shape=A.type.shape) + dtype = linalg_output_dtype(A.type.dtype) + expm = matrix(dtype=dtype, shape=A.type.shape) return Apply(self, [A], [expm]) diff --git a/tests/link/numba/linalg/test_products.py b/tests/link/numba/linalg/test_products.py index 9ff8b9ca43..3e09b96cea 100644 --- a/tests/link/numba/linalg/test_products.py +++ b/tests/link/numba/linalg/test_products.py @@ -75,3 +75,14 @@ def test_expm_size_zero(self): A = pt.matrix("A", dtype=floatX) y = expm(A) compare_numba_and_py([A], [y], [np.zeros((0, 0), dtype=floatX)]) + + def test_expm_integer_input(self): + A = pt.matrix("A", dtype="int64") + y = expm(A) + assert y.type.dtype == "float64" + + val = rng.integers(-2, 3, size=(4, 4)).astype("int64") + original = val.copy() + _, res = compare_numba_and_py([A], [y], [val]) + np.testing.assert_array_equal(val, original) + assert res[0].dtype == np.float64