Skip to content

Fix swapped conditions in local_sqrt_sqr rewrite#1922

Open
WHOIM1205 wants to merge 1 commit intopymc-devs:v3from
WHOIM1205:fix-sqrt-sqr-rewrite-swap
Open

Fix swapped conditions in local_sqrt_sqr rewrite#1922
WHOIM1205 wants to merge 1 commit intopymc-devs:v3from
WHOIM1205:fix-sqrt-sqr-rewrite-swap

Conversation

@WHOIM1205
Copy link
Contributor

Fix: Correct swapped logic in local_sqrt_sqr rewrite

Summary

Fixes a numerical correctness bug in pytensor/tensor/rewriting/math.py
where the rewrite rule local_sqrt_sqr had its conditions swapped.

The previous implementation incorrectly transformed:

  • sqrt(sqr(x))switch(x >= 0, x, nan) (should be abs(x))
  • sqr(sqrt(x))abs(x) (should be switch(x >= 0, x, nan))

This caused silent wrong numerical results, especially for negative inputs.


Root Cause

prev_op (inner op) and node_op (outer op) checks were reversed:

  • The branch matching Sqr(Sqrt(x)) returned abs(x)
  • The branch matching Sqrt(Sqr(x)) returned switch(...)

The return values were correct — but attached to the wrong condition.


Fix

Swapped the two isinstance conditions so that:

  • sqrt(sqr(x))abs(x)
  • sqr(sqrt(x))switch(x >= 0, x, nan)

This is a minimal two-line logical correction.


Tests Updated

  • Corrected expected graph structure.
  • Added numerical tests with negative values.
  • Ensured behavior matches NumPy exactly.
  • Prevents future silent regressions of this type.

Impact

  • Restores correct numerical semantics.
  • Eliminates silent nan pollution in common patterns like sqrt(x**2).
  • Ensures PyMC and downstream users get correct magnitude behavior.

Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
@WHOIM1205
Copy link
Contributor Author

hey @ricardoV94
Swaps the isinstance checks in local_sqrt_sqr so each rewrite matches the correct expression
Previously
sqrt(sqr(x)) and sqr(sqrt(x)) were transformed into the wrong outputs.
This fixes silent numerical errors and adds proper numerical tests (including negative inputs) to prevent regression

@ricardoV94
Copy link
Member

ricardoV94 commented Mar 4, 2026

Can you point to the original PR here in the comments? Want to see what might have failed in the review process

assert equal_computations([out], [expected])

def test_sqr_sqrt_integer_upcast(self):
f = pytensor.function([x], sqr(sqrt(x)), mode=self.mode)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still wouldn't compile and evaluate the functions, the assert_equal_computations shows what the function does already. An independent test would be to compare against the unoptimized function not an expected numerical value

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants