Implements graph rewrite that eliminates redundant monotonic function applications in argmax/argmin operations. For monotonically increasing functions, rewrites argmax(f(x)) → argmax(x) and argmin(f(x)) → argmin(x). For decreasing functions, flips operations: argmax(f(x)) → argmin(x) and argmin(f(x)) → argmax(x). Includes comprehensive tests.
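As a sanity check of the identities being exploited, here is a small stdlib-only sketch (plain Python lists and `math.exp`, not PyTensor) showing that a strictly increasing function preserves the argmax, while a strictly decreasing one turns it into an argmin:

```python
import math

def argmax(xs):
    # Index of the first maximal element
    return max(range(len(xs)), key=lambda i: xs[i])

def argmin(xs):
    # Index of the first minimal element
    return min(range(len(xs)), key=lambda i: xs[i])

xs = [0.3, -1.2, 2.5, 0.9]

# Increasing f (exp): argmax/argmin are unchanged
assert argmax([math.exp(v) for v in xs]) == argmax(xs)
assert argmin([math.exp(v) for v in xs]) == argmin(xs)

# Decreasing f (negation): argmax and argmin swap
assert argmax([-v for v in xs]) == argmin(xs)
assert argmin([-v for v in xs]) == argmax(xs)
```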
Pull request overview
This PR adds a new canonicalization rewrite to optimize solve when the left-hand side is an explicitly constructed diagonal matrix, and also introduces monotonicity metadata + rewrites for argmax/argmin/max/min that leverage monotonic scalar ops.
Changes:
- Add `rewrite_solve_diag` to rewrite `Blockwise(Solve)(pt.diag(d), b)` into elementwise division.
- Introduce a `Monotonicity` enum on scalar ops and add canonicalization rewrites for monotonic `argmax`/`argmin`/`max`/`min` patterns.
- Add tests covering the new diagonal-solve rewrite and the monotonic argmax/argmin/max/min rewrites.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `pytensor/tensor/rewriting/linalg.py` | Adds `rewrite_solve_diag` to replace diagonal `Solve` with division. |
| `tests/tensor/rewriting/test_linalg.py` | Adds tests ensuring the diagonal-solve rewrite removes `Blockwise(Solve)` and matches expected outputs. |
| `pytensor/scalar/basic.py` | Introduces `Monotonicity` and annotates several scalar ops with monotonicity metadata. |
| `pytensor/tensor/rewriting/math.py` | Adds canonicalization rewrites for monotonic argmax/argmin and max/min patterns. |
| `tests/tensor/rewriting/test_math.py` | Adds tests validating the monotonic argmax/argmin/max/min rewrites and graph structure changes. |
```python
@register_canonicalize
@node_rewriter([Argmax])
def local_argmax_argmin_monotonic(fgraph, node):
    """
    Optimize argmax/argmin with monotonic functions:
    - argmax(f_inc(x)) -> argmax(x) for monotonically increasing f
    - argmin(f_inc(x)) -> argmin(x) for monotonically increasing f
    - argmax(f_dec(x)) -> argmin(x) for monotonically decreasing f
    - argmin(f_dec(x)) -> argmax(x) for monotonically decreasing f
    Note: argmin is represented as Argmax(Neg(...)) internally
    """
```
The PR description/title are focused on the diagonal solve rewrite, but this file also introduces new monotonicity-based canonicalization rewrites for argmax/argmin/max/min plus a new scalar-op API (Monotonicity). Please either update the PR description to cover these additional changes, or split them into a separate PR to keep review scope aligned.
```python
    is_argmin = _is_argmin(node)
    argmax_input = node.inputs[0]

    # If argmin, skip the Neg wrapper to get to the monotonic function
    if is_argmin:
        if not argmax_input.owner:
            return False
        argmax_input = argmax_input.owner.inputs[0]  # Skip Neg

    if not argmax_input.owner:
        return False

    inner_op = argmax_input.owner.op

    if not isinstance(inner_op, Elemwise):
        return False

    scalar_op = inner_op.scalar_op

    monotonicity = getattr(scalar_op, "monotonicity", Monotonicity.NONMONOTONIC)
    is_increasing = monotonicity == Monotonicity.INCREASING
    is_decreasing = monotonicity == Monotonicity.DECREASING

    if not (is_increasing or is_decreasing):
        return False

    x = argmax_input.owner.inputs[0]

    # Determine new operation based on current op and monotonicity
    if is_argmin:
        if is_increasing:
            # argmin(f_inc(x)) -> argmin(x) = Argmax(Neg(x))
            new_output = argmin(x, axis=node.op.axis)
        else:  # is_decreasing
            # argmin(f_dec(x)) -> argmax(x)
            new_output = node.op(x)
    else:  # is argmax
        if is_increasing:
            # argmax(f_inc(x)) -> argmax(x)
            new_output = node.op(x)
        else:  # is_decreasing
            # argmax(f_dec(x)) -> argmin(x) = Argmax(Neg(x))
            new_output = argmin(x, axis=node.op.axis)

    copy_stack_trace(node.outputs[0], new_output)
```
These monotonicity-based rewrites are not semantics-preserving for many scalar ops in floating-point: they can change results in the presence of NaNs, infinities, domain errors (e.g., log, sqrt, arccos), or overflow/underflow that introduces ties (e.g., exp(1000) and exp(1001) both becoming inf). In those cases argmax(f(x))/argmin(f(x)) is not guaranteed to equal argmax(x)/argmin(x) due to different tie-breaking. This rewrite should be constrained to cases where f is proven total + strictly monotone over the realized inputs (or removed).
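To make the concern concrete, here is a stdlib-only counterexample (plain Python, not PyTensor): `exp` underflows two distinct inputs to the same value, creating a tie that first-occurrence tie-breaking resolves differently than on the original inputs:

```python
import math

def argmax(xs):
    # First-occurrence tie-breaking, like most argmax implementations
    return max(range(len(xs)), key=lambda i: xs[i])

xs = [-900.0, -800.0]
ys = [math.exp(v) for v in xs]  # both underflow to 0.0

assert ys == [0.0, 0.0]          # exp collapsed distinct inputs into a tie
assert argmax(xs) == 1           # -800.0 is the larger input
assert argmax(ys) == 0           # the tie resolves to the first index
assert argmax(ys) != argmax(xs)  # so argmax(exp(x)) != argmax(x) here
```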
Suggested change:

```python
    is_argmin = _is_argmin(node)
    argmax_input = node.inputs[0]
    # If argmin, skip the Neg wrapper to get to the monotonic function
    if is_argmin:
        if not argmax_input.owner:
            return False
        argmax_input = argmax_input.owner.inputs[0]  # Skip Neg
    if not argmax_input.owner:
        return False
    inner_op = argmax_input.owner.op
    if not isinstance(inner_op, Elemwise):
        return False
    scalar_op = inner_op.scalar_op
    monotonicity = getattr(scalar_op, "monotonicity", Monotonicity.NONMONOTONIC)
    is_increasing = monotonicity == Monotonicity.INCREASING
    is_decreasing = monotonicity == Monotonicity.DECREASING
    if not (is_increasing or is_decreasing):
        return False
    x = argmax_input.owner.inputs[0]
    # Determine new operation based on current op and monotonicity
    if is_argmin:
        if is_increasing:
            # argmin(f_inc(x)) -> argmin(x) = Argmax(Neg(x))
            new_output = argmin(x, axis=node.op.axis)
        else:  # is_decreasing
            # argmin(f_dec(x)) -> argmax(x)
            new_output = node.op(x)
    else:  # is argmax
        if is_increasing:
            # argmax(f_inc(x)) -> argmax(x)
            new_output = node.op(x)
        else:  # is_decreasing
            # argmax(f_dec(x)) -> argmin(x) = Argmax(Neg(x))
            new_output = argmin(x, axis=node.op.axis)
    copy_stack_trace(node.outputs[0], new_output)
    # NOTE:
    # Monotonicity-based rewrites of argmax/argmin are not generally
    # semantics-preserving in floating-point arithmetic: NaNs, infinities,
    # domain errors, and over/underflow-induced ties can change which index
    # is returned. Until we can prove that the wrapped scalar op is total and
    # strictly monotone over the realized inputs, we disable this optimization.
    return False
```
```python
    is_max = isinstance(node.op, Max)
    input_arg = node.inputs[0]

    if not input_arg.owner:
        return False

    inner_op = input_arg.owner.op

    if not isinstance(inner_op, Elemwise):
        return False

    scalar_op = inner_op.scalar_op

    monotonicity = getattr(scalar_op, "monotonicity", Monotonicity.NONMONOTONIC)
    is_increasing = monotonicity == Monotonicity.INCREASING
    is_decreasing = monotonicity == Monotonicity.DECREASING

    if not (is_increasing or is_decreasing):
        return False

    x = input_arg.owner.inputs[0]

    # Determine new operation based on current op and monotonicity
    if is_max:
        if is_increasing:
            # max(f_inc(x)) -> f_inc(max(x))
            inner_result = node.op.make_node(x).outputs[0]
        else:  # is_decreasing
            # max(f_dec(x)) -> f_dec(min(x))
            inner_result = Min(axis=node.op.axis)(x)
    else:  # is_min
        if is_increasing:
            # min(f_inc(x)) -> f_inc(min(x))
            inner_result = node.op.make_node(x).outputs[0]
        else:  # is_decreasing
            # min(f_dec(x)) -> f_dec(max(x))
            inner_result = Max(axis=node.op.axis)(x)

    # Apply the monotonic function to the result
    new_output = inner_op.make_node(inner_result).outputs[0]

    copy_stack_trace(node.outputs[0], new_output)
    return [new_output]
```
Same correctness concern applies to moving monotonic functions outside max/min: for functions with restricted domains or FP overflow/NaN behavior, max(f(x)) -> f(max(x)) (and the decreasing variants) can change both the selected element and the final value. This needs additional guards proving inputs are in-domain/finite and that f is strictly order-preserving for the dtype, otherwise the rewrite is unsafe in canonicalization.
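A stdlib-only illustration of the domain issue (plain Python, not PyTensor): for `sqrt`, the original `max(f(x))` evaluates `f` on every element and raises on out-of-domain values, whereas the rewritten `f(max(x))` only ever sees the maximum and silently succeeds, so the rewrite changes observable behaviour:

```python
import math

xs = [-1.0, 4.0]

# Original graph: sqrt is applied elementwise, so it hits the
# out-of-domain value -1.0 and raises.
try:
    max(math.sqrt(v) for v in xs)
    raised = False
except ValueError:
    raised = True
assert raised

# Rewritten graph: sqrt(max(x)) only ever sees the maximum, 4.0.
assert math.sqrt(max(xs)) == 2.0
```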
Suggested change:

```python
    is_max = isinstance(node.op, Max)
    input_arg = node.inputs[0]
    if not input_arg.owner:
        return False
    inner_op = input_arg.owner.op
    if not isinstance(inner_op, Elemwise):
        return False
    scalar_op = inner_op.scalar_op
    monotonicity = getattr(scalar_op, "monotonicity", Monotonicity.NONMONOTONIC)
    is_increasing = monotonicity == Monotonicity.INCREASING
    is_decreasing = monotonicity == Monotonicity.DECREASING
    if not (is_increasing or is_decreasing):
        return False
    x = input_arg.owner.inputs[0]
    # Determine new operation based on current op and monotonicity
    if is_max:
        if is_increasing:
            # max(f_inc(x)) -> f_inc(max(x))
            inner_result = node.op.make_node(x).outputs[0]
        else:  # is_decreasing
            # max(f_dec(x)) -> f_dec(min(x))
            inner_result = Min(axis=node.op.axis)(x)
    else:  # is_min
        if is_increasing:
            # min(f_inc(x)) -> f_inc(min(x))
            inner_result = node.op.make_node(x).outputs[0]
        else:  # is_decreasing
            # min(f_dec(x)) -> f_dec(max(x))
            inner_result = Max(axis=node.op.axis)(x)
    # Apply the monotonic function to the result
    new_output = inner_op.make_node(inner_result).outputs[0]
    copy_stack_trace(node.outputs[0], new_output)
    return [new_output]
    # NOTE: Disabled optimization.
    # Moving monotonic scalar functions outside of Max/Min (e.g.,
    # max(f(x)) -> f(max(x))) is only sound when we can prove that:
    #   * the function is strictly order-preserving for the dtype,
    #   * all inputs are within its domain and finite, and
    #   * NaN/Inf behaviour is preserved.
    # These properties are not currently guaranteed here, so we
    # conservatively skip this rewrite.
    return False
```
```python
    mode = get_default_mode()

    for f in [pt.exp, pt.log1p, pt.sqrt]:
        # Compile the unrewritten and expected graphs
        unrewritten = pt.argmax(f(x), axis=axis)
        expected = pt.argmax(x, axis=axis)

        # Create functions to apply rewrites
        fn_unrewritten = function([x], unrewritten, mode=mode)
        fn_expected = function([x], expected, mode=mode)
```
These tests compile a separate function for each (axis, f) pair (and also compile the expected graph), which can make the rewrite test suite noticeably slower. Consider using tests.unittest_tools.assert_equal_computations/graph comparisons and minimizing compilations (e.g., reuse one compiled function per axis, or compile once and test multiple f via parametrization at the graph level) in line with the repo’s testing guidance.
```python
    # b_ndim tells us whether b's core case is a vector (1) or matrix (2)
    b_ndim = node.op.core_op.b_ndim

    # b_ndim=1: b shape (N,) -> result shape (N,)
    # b_ndim=2: b shape (N, K) -> result shape (N, K)
    b_transposed = b[None, :] if b_ndim == 1 else b.mT
    new_out = (b_transposed / pt.expand_dims(d, -2)).mT
    if b_ndim == 1:
        new_out = new_out.squeeze(-1)
```
The division-by-diagonal construction here duplicates the assume_a == "diagonal" fast-path logic in pytensor.tensor.slinalg.solve (transpose, expand_dims(..., -2), .mT, squeeze). To avoid future divergence/bugfixes being applied in only one place, consider factoring this pattern into a small shared helper (or reusing the existing diagonal-solve path).
Rewrite `solve` with diagonal matrices

Partial implementation of #1791.

What was done

Added a graph rewrite `rewrite_solve_diag` in `pytensor/tensor/rewriting/linalg.py` that detects when the first argument to `solve` is the output of `pt.diag(d)` (i.e., an `AllocDiag` node on the main diagonal) and replaces the expensive `Blockwise(Solve(...))` node with elementwise division.

For a diagonal matrix `A = diag(d)`, the linear system `A @ x = b` has the closed-form solution `x = b / d`, which avoids the full LU factorisation performed by `scipy.linalg.solve`.

The rewrite handles both `b_ndim=1` (vector `b`) and `b_ndim=2` (matrix `b`):
- `b_ndim=1`: `solve(diag(d), b)` → `b / d`
- `b_ndim=2`: `solve(diag(d), b, b_ndim=2)` → `b / d[:, None]`

The rewrite is registered under `@register_canonicalize` so it fires automatically in `FAST_RUN` and `FAST_COMPILE` modes.

Tests were added in `tests/tensor/rewriting/test_linalg.py`:
- `test_solve_diag_vector_b` — verifies the `Blockwise(Solve)` node is eliminated and the result matches both `b / d` and the unoptimised solve.
- `test_solve_diag_matrix_b` — same for the matrix `b` case.

What remains to be done

1. Handle `eye * scalar/vector/matrix` diagonal pattern

   The current rewrite only triggers when `a` is produced by `pt.diag(d)` (`AllocDiag`). A second common pattern is `pt.eye(n) * x`, where `x` can be a scalar, a vector (broadcast over the diagonal), or a full matrix (element-wise multiply with the identity). This pattern already has precedent in existing rewrites (e.g. `rewrite_inv_diag_to_diag_reciprocal`, `rewrite_det_diag_from_eye_mul`). The same diagonal detection logic should be extended to cover `solve(eye * x, b)`.

2. Rewrite `dot`/`matmul` with diagonal matrices

   Issue #1791 also asks for rewrites for `dot` and `matmul` when one of the operands is diagonal:
   - `dot(diag(d), x)` → `d[:, None] * x`
   - `dot(x, diag(d))` → `x * d`
   - `matmul` and the batched `Blockwise` variants.

   These rewrites are analogous to `rewrite_solve_diag` but target `Dot`/`MatMul` nodes instead of `Solve` nodes.