From 568cdb3fc71572d779a96d93036ee0065a7c8cf3 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Wed, 27 May 2026 22:39:21 +0000 Subject: [PATCH] Use vfma directly in TwoProduct paths of double.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `dmul`, `dsqu`, `dsqrt`, `ddiv`, `drec` (the `::True` branches, gated on `fma_fast()`) implement the TwoProduct error-free transformation — `Double(z, fma(x, y, -z))` — to extract the rounding error of `x*y` as the low part of a double-double. The identity requires a single-rounded FMA, but the code was calling `vfmsub` / `vfnmadd` (the `vfmadd` family), which lower to `llvm.fmuladd`. The latter has explicit "may fuse or not, optimizer's choice" semantics, and LLVM's middle-end constant-folds the `(x*y) + (-(x*y))` pattern to `0` even on hardware that has FMA. Every double-double error term was being destroyed, which propagated as catastrophic precision loss in downstream `asinh`/`sin`/`cos`/`tan`/`asin`/`acos`/`sincos` for Float32 inputs — `SLEEFPirates.asinh(-959.98f0)` returned ≈ `-7.62` instead of ≈ `-7.56006` (~10⁵ ULP). Switch the 11 TwoProduct-style calls (in 10 methods; the `Double`/`Double` `ddiv` has two) to `vfma` (which lowers to `llvm.fma` and forbids that fold). The `vfmadd` family is left as-is — its "may-fuse" contract is correct for the perf-oriented callers it serves; only the TwoProduct paths actually require guaranteed FMA semantics. Verified on Apple M-series (Julia 1.11.8): - `SLEEFPirates.asinh(-959.98f0) == Base.asinh(-959.98f0)` to Float32 precision. - `SLEEFPirates`'s full accuracy suite: 560/565 pass / 1 fail (`tanh_fast` 1-ULP edge, separate issue) / 4 broken (pre-existing `pow_fast`). Previously there were many failures (`asinh`, `sin`, `cos`, `tan`, `acos`, `cos_sincos`, `sin_sincos`, `asin`). - LoopVectorization `dot.jl` (171/171) and `gemm.jl` (18137/18137) unchanged. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/special/double.jl | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/special/double.jl b/src/special/double.jl index 019e758c..c81baedc 100644 --- a/src/special/double.jl +++ b/src/special/double.jl @@ -315,9 +315,19 @@ end end # two-prod-fma +# +# These `::True` branches implement the TwoProduct error-free transformation +# (e.g. `Double(z, fma(x, y, -z))`), which extracts the rounding error of +# `x*y` as the low part of a double-double. The identity requires a +# **single-rounded** FMA; using the may-fuse `vfmsub`/`vfnmadd` lets LLVM +# constant-fold `(x*y) + (-(x*y))` to `0` even on hardware that has FMA, +# which destroys every double-double error term and propagates as wildly +# wrong results in downstream `asinh`/`sin`/`cos`/etc. (e.g. ~10^5 ULP on +# Float32). Call `vfma` (which lowers to `llvm.fma`) directly with an +# explicit negation so the FMA is preserved. @inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::True) z = (x * y) - Double(z, vfmsub(x, y, z)) + Double(z, vfma(x, y, -z)) end @inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::False) hx, lx = splitprec(x) @@ -329,7 +339,7 @@ end end @inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::True) z = (x.hi * y) - Double(z, vfmsub(x.hi, y, z) + x.lo * y) + Double(z, vfma(x.hi, y, -z) + x.lo * y) end @inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::False) hx, lx = splitprec(x.hi) @@ -341,7 +351,7 @@ end end @inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True) z = x.hi * y.hi - Double(z, vfmsub(x.hi, y.hi, z) + x.hi * y.lo + x.lo * y.hi) + Double(z, vfma(x.hi, y.hi, -z) + x.hi * y.lo + x.lo * y.hi) end @inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False) hx, lx = splitprec(x.hi) @@ -361,7 +371,7 @@ end # x^2 @inline function dsqu(x::T, ::True) where {T<:vIEEEFloat} z = x * x - Double(z, vfmsub(x, x, z)) + Double(z, 
vfma(x, x, -z)) end @inline function dsqu(x::T, ::False) where {T<:vIEEEFloat} hx, lx = splitprec(x) @@ -372,7 +382,7 @@ end end @inline function dsqu(x::Double{T}, ::True) where {T<:vIEEEFloat} z = x.hi * x.hi - Double(z, vfmsub(x.hi, x.hi, z) + (x.hi * (x.lo + x.lo))) + Double(z, vfma(x.hi, x.hi, -z) + (x.hi * (x.lo + x.lo))) end @inline function dsqu(x::Double{T}, ::False) where {T<:vIEEEFloat} hx, lx = splitprec(x.hi) @@ -386,7 +396,7 @@ end # sqrt(x) @inline function dsqrt(x::Double{T}, ::True) where {T<:vIEEEFloat} zhi = @fastmath sqrt(x.hi) - Double(zhi, (x.lo + vfnmadd(zhi, zhi, x.hi)) / (zhi + zhi)) + Double(zhi, (x.lo + vfma(-zhi, zhi, x.hi)) / (zhi + zhi)) end @inline function dsqrt(x::Double{T}, ::False) where {T<:vIEEEFloat} c = @fastmath sqrt(x.hi) @@ -399,7 +409,7 @@ end @inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True) invy = inv(y.hi) zhi = (x.hi * invy) - Double(zhi, ((vfnmadd(zhi, y.hi, x.hi) + vfnmadd(zhi, y.lo, x.lo)) * invy)) + Double(zhi, ((vfma(-zhi, y.hi, x.hi) + vfma(-zhi, y.lo, x.lo)) * invy)) end @inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False) @ieee begin @@ -412,7 +422,7 @@ end @inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::True) ry = inv(y) r = (x * ry) - Double(r, (vfnmadd(r, y, x) * ry)) + Double(r, (vfma(-r, y, x) * ry)) end @inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::False) @ieee begin @@ -427,7 +437,7 @@ end # 1/x @inline function drec(x::vIEEEFloat, ::True) zhi = inv(x) - Double(zhi, (vfnmadd(zhi, x, one(eltype(x))) * zhi)) + Double(zhi, (vfma(-zhi, x, one(eltype(x))) * zhi)) end @inline function drec(x::vIEEEFloat, ::False) @ieee begin @@ -439,7 +449,7 @@ end @inline function drec(x::Double{<:vIEEEFloat}, ::True) zhi = inv(x.hi) - Double(zhi, ((vfnmadd(zhi, x.hi, one(eltype(x))) - (zhi * x.lo)) * zhi)) + Double(zhi, ((vfma(-zhi, x.hi, one(eltype(x))) - (zhi * x.lo)) * zhi)) end @inline function drec(x::Double{<:vIEEEFloat}, ::False) @ieee begin