Diagonal solve dot rewrite#1930

Closed
Jasjeet-Singh-S wants to merge 6 commits into pymc-devs:v3 from Jasjeet-Singh-S:diagonal-solve-dot-rewrite

Conversation

@Jasjeet-Singh-S
Rewrite solve with diagonal matrices

Partial implementation of #1791.

What was done

Added a graph rewrite rewrite_solve_diag in pytensor/tensor/rewriting/linalg.py that detects when the first argument to solve is the output of pt.diag(d) (i.e., an AllocDiag node on the main diagonal) and replaces the expensive Blockwise(Solve(...)) node with elementwise division.

For a diagonal matrix A = diag(d), the linear system A @ x = b has the closed-form solution x = b / d, which avoids the full LU factorisation performed by scipy.linalg.solve.

The rewrite handles both b_ndim=1 (vector b) and b_ndim=2 (matrix b):

  • b_ndim=1: solve(diag(d), b) → b / d
  • b_ndim=2: solve(diag(d), b, b_ndim=2) → b / d[:, None]
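As a sanity check, both cases can be verified against NumPy's general solver (illustrative only, not part of the PR; the rewrite itself operates on the PyTensor graph):

```python
import numpy as np

rng = np.random.default_rng(0)
d = rng.uniform(1.0, 2.0, size=4)   # diagonal entries, kept away from zero
A = np.diag(d)

# b_ndim=1: vector right-hand side, result is elementwise division
b1 = rng.normal(size=4)
assert np.allclose(np.linalg.solve(A, b1), b1 / d)

# b_ndim=2: matrix right-hand side, each row i of b is divided by d[i]
b2 = rng.normal(size=(4, 3))
assert np.allclose(np.linalg.solve(A, b2), b2 / d[:, None])
```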

The rewrite is registered under @register_canonicalize so it fires automatically in FAST_RUN and FAST_COMPILE modes.

Tests were added in tests/tensor/rewriting/test_linalg.py:

  • test_solve_diag_vector_b — verifies the Blockwise(Solve) node is eliminated and the result matches both b / d and the unoptimised solve.
  • test_solve_diag_matrix_b — same for the matrix b case.

What remains to be done

1. Handle eye * scalar/vector/matrix diagonal pattern

The current rewrite only triggers when a is produced by pt.diag(d) (AllocDiag). A second common pattern is pt.eye(n) * x, where x can be a scalar, a vector (broadcast over the diagonal), or a full matrix (element-wise multiply with the identity). This pattern already has precedent in existing rewrites (e.g. rewrite_inv_diag_to_diag_reciprocal, rewrite_det_diag_from_eye_mul). The same diagonal detection logic should be extended to cover solve(eye * x, b).
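The three eye * x shapes can be illustrated in NumPy (a sketch of the patterns the extended detection would need to recognise):

```python
import numpy as np

n = 3
eye = np.eye(n)

# scalar: eye(n) * c is a diagonal matrix with constant diagonal c
c = 2.0
assert np.allclose(eye * c, np.diag(np.full(n, c)))

# vector: broadcasting places v along the diagonal, so eye(n) * v == diag(v)
v = np.array([1.0, 2.0, 3.0])
assert np.allclose(eye * v, np.diag(v))

# matrix: elementwise multiply with the identity keeps only M's diagonal
M = np.arange(9.0).reshape(3, 3) + 1.0
assert np.allclose(eye * M, np.diag(np.diag(M)))
```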

2. Rewrite dot/matmul with diagonal matrices

Issue #1791 also asks for rewrites for dot and matmul when one of the operands is diagonal:

  • dot(diag(d), x) → d[:, None] * x
  • dot(x, diag(d)) → x * d
  • Same patterns for matmul and the batched Blockwise variants.

These rewrites are analogous to rewrite_solve_diag but target Dot/MatMul nodes instead of Solve nodes.
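The dot identities are easy to confirm numerically (NumPy sketch, not the PyTensor rewrite itself):

```python
import numpy as np

rng = np.random.default_rng(1)
d = rng.normal(size=4)
x = rng.normal(size=(4, 4))

# dot(diag(d), x): scales row i of x by d[i]
assert np.allclose(np.diag(d) @ x, d[:, None] * x)

# dot(x, diag(d)): scales column j of x by d[j]
assert np.allclose(x @ np.diag(d), x * d)
```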

This PR also implements a graph rewrite that eliminates redundant monotonic function applications in argmax/argmin operations. For monotonically increasing functions, it rewrites argmax(f(x)) → argmax(x) and argmin(f(x)) → argmin(x). For decreasing functions, it flips the operations: argmax(f(x)) → argmin(x) and argmin(f(x)) → argmax(x). Includes comprehensive tests.
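For well-behaved (finite, in-domain) inputs, the intended identities can be checked numerically (NumPy sketch):

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(size=10)

# increasing f (exp): argmax/argmin are unchanged
assert np.argmax(np.exp(x)) == np.argmax(x)
assert np.argmin(np.exp(x)) == np.argmin(x)

# decreasing f (negation): argmax and argmin swap
assert np.argmax(-x) == np.argmin(x)
assert np.argmin(-x) == np.argmax(x)
```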
Copilot AI review requested due to automatic review settings March 3, 2026 18:09
@Jasjeet-Singh-S Jasjeet-Singh-S marked this pull request as draft March 3, 2026 18:13

Copilot AI left a comment


Pull request overview

This PR adds a new canonicalization rewrite to optimize solve when the left-hand side is an explicitly constructed diagonal matrix, and also introduces monotonicity metadata + rewrites for argmax/argmin/max/min that leverage monotonic scalar ops.

Changes:

  • Add rewrite_solve_diag to rewrite Blockwise(Solve)(pt.diag(d), b) into elementwise division.
  • Introduce a Monotonicity enum on scalar ops and add canonicalization rewrites for monotonic argmax/argmin/max/min patterns.
  • Add tests covering the new diagonal-solve rewrite and the monotonic argmax/argmin/max/min rewrites.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

File Description
pytensor/tensor/rewriting/linalg.py Adds rewrite_solve_diag to replace diagonal Solve with division.
tests/tensor/rewriting/test_linalg.py Adds tests ensuring the diagonal-solve rewrite removes Blockwise(Solve) and matches expected outputs.
pytensor/scalar/basic.py Introduces Monotonicity and annotates several scalar ops with monotonicity metadata.
pytensor/tensor/rewriting/math.py Adds canonicalization rewrites for monotonic argmax/argmin and max/min patterns.
tests/tensor/rewriting/test_math.py Adds tests validating the monotonic argmax/argmin/max/min rewrites and graph structure changes.

Comment on lines +3912 to +3922
@register_canonicalize
@node_rewriter([Argmax])
def local_argmax_argmin_monotonic(fgraph, node):
"""
Optimize argmax/argmin with monotonic functions:
- argmax(f_inc(x)) -> argmax(x) for monotonically increasing f
- argmin(f_inc(x)) -> argmin(x) for monotonically increasing f
- argmax(f_dec(x)) -> argmin(x) for monotonically decreasing f
- argmin(f_dec(x)) -> argmax(x) for monotonically decreasing f
Note: argmin is represented as Argmax(Neg(...)) internally
"""

Copilot AI Mar 3, 2026


The PR description/title are focused on the diagonal solve rewrite, but this file also introduces new monotonicity-based canonicalization rewrites for argmax/argmin/max/min plus a new scalar-op API (Monotonicity). Please either update the PR description to cover these additional changes, or split them into a separate PR to keep review scope aligned.

Comment on lines +3927 to +3971
is_argmin = _is_argmin(node)
argmax_input = node.inputs[0]

# If argmin, skip the Neg wrapper to get to the monotonic function
if is_argmin:
if not argmax_input.owner:
return False
argmax_input = argmax_input.owner.inputs[0] # Skip Neg

if not argmax_input.owner:
return False

inner_op = argmax_input.owner.op

if not isinstance(inner_op, Elemwise):
return False

scalar_op = inner_op.scalar_op

monotonicity = getattr(scalar_op, "monotonicity", Monotonicity.NONMONOTONIC)
is_increasing = monotonicity == Monotonicity.INCREASING
is_decreasing = monotonicity == Monotonicity.DECREASING

if not (is_increasing or is_decreasing):
return False

x = argmax_input.owner.inputs[0]

# Determine new operation based on current op and monotonicity
if is_argmin:
if is_increasing:
# argmin(f_inc(x)) -> argmin(x) = Argmax(Neg(x))
new_output = argmin(x, axis=node.op.axis)
else: # is_decreasing
# argmin(f_dec(x)) -> argmax(x)
new_output = node.op(x)
else: # is argmax
if is_increasing:
# argmax(f_inc(x)) -> argmax(x)
new_output = node.op(x)
else: # is_decreasing
# argmax(f_dec(x)) -> argmin(x) = Argmax(Neg(x))
new_output = argmin(x, axis=node.op.axis)

copy_stack_trace(node.outputs[0], new_output)

Copilot AI Mar 3, 2026


These monotonicity-based rewrites are not semantics-preserving for many scalar ops in floating-point: they can change results in the presence of NaNs, infinities, domain errors (e.g., log, sqrt, arccos), or overflow/underflow that introduces ties (e.g., exp(1000) and exp(1001) both becoming inf). In those cases argmax(f(x))/argmin(f(x)) is not guaranteed to equal argmax(x)/argmin(x) due to different tie-breaking. This rewrite should be constrained to cases where f is proven total + strictly monotone over the realized inputs (or removed).
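The overflow-tie failure mode described in this comment is easy to reproduce in NumPy: exp saturates both entries to inf, and argmax's first-occurrence tie-breaking then disagrees with argmax of the original input.

```python
import numpy as np

x = np.array([1000.0, 1001.0])
with np.errstate(over="ignore"):
    fx = np.exp(x)          # both entries overflow to inf

assert np.isinf(fx).all()
# argmax(x) picks index 1, but argmax(exp(x)) ties and picks index 0,
# so the rewrite argmax(f(x)) -> argmax(x) changes the result here
assert np.argmax(x) == 1
assert np.argmax(fx) == 0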

Suggested change
is_argmin = _is_argmin(node)
argmax_input = node.inputs[0]
# If argmin, skip the Neg wrapper to get to the monotonic function
if is_argmin:
if not argmax_input.owner:
return False
argmax_input = argmax_input.owner.inputs[0] # Skip Neg
if not argmax_input.owner:
return False
inner_op = argmax_input.owner.op
if not isinstance(inner_op, Elemwise):
return False
scalar_op = inner_op.scalar_op
monotonicity = getattr(scalar_op, "monotonicity", Monotonicity.NONMONOTONIC)
is_increasing = monotonicity == Monotonicity.INCREASING
is_decreasing = monotonicity == Monotonicity.DECREASING
if not (is_increasing or is_decreasing):
return False
x = argmax_input.owner.inputs[0]
# Determine new operation based on current op and monotonicity
if is_argmin:
if is_increasing:
# argmin(f_inc(x)) -> argmin(x) = Argmax(Neg(x))
new_output = argmin(x, axis=node.op.axis)
else: # is_decreasing
# argmin(f_dec(x)) -> argmax(x)
new_output = node.op(x)
else: # is argmax
if is_increasing:
# argmax(f_inc(x)) -> argmax(x)
new_output = node.op(x)
else: # is_decreasing
# argmax(f_dec(x)) -> argmin(x) = Argmax(Neg(x))
new_output = argmin(x, axis=node.op.axis)
copy_stack_trace(node.outputs[0], new_output)
# NOTE:
# Monotonicity-based rewrites of argmax/argmin are not generally
# semantics-preserving in floating-point arithmetic: NaNs, infinities,
# domain errors, and over/underflow-induced ties can change which index
# is returned. Until we can prove that the wrapped scalar op is total and
# strictly monotone over the realized inputs, we disable this optimization.
return False

Comment on lines +3988 to +4030
is_max = isinstance(node.op, Max)
input_arg = node.inputs[0]

if not input_arg.owner:
return False

inner_op = input_arg.owner.op

if not isinstance(inner_op, Elemwise):
return False

scalar_op = inner_op.scalar_op

monotonicity = getattr(scalar_op, "monotonicity", Monotonicity.NONMONOTONIC)
is_increasing = monotonicity == Monotonicity.INCREASING
is_decreasing = monotonicity == Monotonicity.DECREASING

if not (is_increasing or is_decreasing):
return False

x = input_arg.owner.inputs[0]

# Determine new operation based on current op and monotonicity
if is_max:
if is_increasing:
# max(f_inc(x)) -> f_inc(max(x))
inner_result = node.op.make_node(x).outputs[0]
else: # is_decreasing
# max(f_dec(x)) -> f_dec(min(x))
inner_result = Min(axis=node.op.axis)(x)
else: # is_min
if is_increasing:
# min(f_inc(x)) -> f_inc(min(x))
inner_result = node.op.make_node(x).outputs[0]
else: # is_decreasing
# min(f_dec(x)) -> f_dec(max(x))
inner_result = Max(axis=node.op.axis)(x)

# Apply the monotonic function to the result
new_output = inner_op.make_node(inner_result).outputs[0]

copy_stack_trace(node.outputs[0], new_output)
return [new_output]

Copilot AI Mar 3, 2026


Same correctness concern applies to moving monotonic functions outside max/min: for functions with restricted domains or FP overflow/NaN behavior, max(f(x)) -> f(max(x)) (and the decreasing variants) can change both the selected element and the final value. This needs additional guards proving inputs are in-domain/finite and that f is strictly order-preserving for the dtype, otherwise the rewrite is unsafe in canonicalization.
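For example, a domain error makes the two sides diverge (NumPy sketch): sqrt of a negative entry produces NaN, which max propagates, while applying sqrt after max never sees the out-of-domain value.

```python
import numpy as np

x = np.array([-1.0, 4.0])
with np.errstate(invalid="ignore"):
    fx = np.sqrt(x)         # sqrt(-1) -> nan

# np.max propagates NaN, so max(sqrt(x)) != sqrt(max(x))
assert np.isnan(np.max(fx))
assert np.sqrt(np.max(x)) == 2.0
```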

Suggested change
is_max = isinstance(node.op, Max)
input_arg = node.inputs[0]
if not input_arg.owner:
return False
inner_op = input_arg.owner.op
if not isinstance(inner_op, Elemwise):
return False
scalar_op = inner_op.scalar_op
monotonicity = getattr(scalar_op, "monotonicity", Monotonicity.NONMONOTONIC)
is_increasing = monotonicity == Monotonicity.INCREASING
is_decreasing = monotonicity == Monotonicity.DECREASING
if not (is_increasing or is_decreasing):
return False
x = input_arg.owner.inputs[0]
# Determine new operation based on current op and monotonicity
if is_max:
if is_increasing:
# max(f_inc(x)) -> f_inc(max(x))
inner_result = node.op.make_node(x).outputs[0]
else: # is_decreasing
# max(f_dec(x)) -> f_dec(min(x))
inner_result = Min(axis=node.op.axis)(x)
else: # is_min
if is_increasing:
# min(f_inc(x)) -> f_inc(min(x))
inner_result = node.op.make_node(x).outputs[0]
else: # is_decreasing
# min(f_dec(x)) -> f_dec(max(x))
inner_result = Max(axis=node.op.axis)(x)
# Apply the monotonic function to the result
new_output = inner_op.make_node(inner_result).outputs[0]
copy_stack_trace(node.outputs[0], new_output)
return [new_output]
# NOTE: Disabled optimization.
# Moving monotonic scalar functions outside of Max/Min (e.g.,
# max(f(x)) -> f(max(x))) is only sound when we can prove that:
# * the function is strictly order-preserving for the dtype,
# * all inputs are within its domain and finite, and
# * NaN/Inf behaviour is preserved.
# These properties are not currently guaranteed here, so we
# conservatively skip this rewrite.
return False

Comment on lines +5036 to +5046
mode = get_default_mode()

for f in [pt.exp, pt.log1p, pt.sqrt]:
# Compile the unrewritten and expected graphs
unrewritten = pt.argmax(f(x), axis=axis)
expected = pt.argmax(x, axis=axis)

# Create functions to apply rewrites
fn_unrewritten = function([x], unrewritten, mode=mode)
fn_expected = function([x], expected, mode=mode)


Copilot AI Mar 3, 2026


These tests compile a separate function for each (axis, f) pair (and also compile the expected graph), which can make the rewrite test suite noticeably slower. Consider using tests.unittest_tools.assert_equal_computations/graph comparisons and minimizing compilations (e.g., reuse one compiled function per axis, or compile once and test multiple f via parametrization at the graph level) in line with the repo’s testing guidance.

Copilot generated this review using guidance from repository custom instructions.
Comment on lines +1174 to +1182
# b_ndim tells us whether b's core case is a vector (1) or matrix (2)
b_ndim = node.op.core_op.b_ndim

# b_ndim=1: b shape (N,) -> result shape (N,)
# b_ndim=2: b shape (N, K) -> result shape (N, K)
b_transposed = b[None, :] if b_ndim == 1 else b.mT
new_out = (b_transposed / pt.expand_dims(d, -2)).mT
if b_ndim == 1:
new_out = new_out.squeeze(-1)

Copilot AI Mar 3, 2026


The division-by-diagonal construction here duplicates the assume_a == "diagonal" fast-path logic in pytensor.tensor.slinalg.solve (transpose, expand_dims(..., -2), .mT, squeeze). To avoid future divergence/bugfixes being applied in only one place, consider factoring this pattern into a small shared helper (or reusing the existing diagonal-solve path).

@Jasjeet-Singh-S Jasjeet-Singh-S deleted the diagonal-solve-dot-rewrite branch March 4, 2026 08:36