diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 78016c0f52f71..f7c0cf1b6e1a6 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -630,15 +630,36 @@ fn round_columnar( } } +trait RoundEven: num_traits::Float { + fn round_even(self) -> Self; +} + +impl RoundEven for f32 { + fn round_even(self) -> Self { + self.round_ties_even() + } +} + +impl RoundEven for f64 { + fn round_even(self) -> Self { + self.round_ties_even() + } +} + fn round_float(value: T, decimal_places: i32) -> Result where - T: num_traits::Float, + T: RoundEven, { + if decimal_places == 0 { + return Ok(value.round_even()); + } + let factor = T::from(10_f64.powi(decimal_places)).ok_or_else(|| { ArrowError::ComputeError(format!( "Invalid value for decimal places: {decimal_places}" )) })?; + Ok((value * factor).round() / factor) } @@ -809,6 +830,22 @@ mod test { assert_eq!(floats, &expected); } + #[test] + fn test_round_even_f32_one_input() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![2.5, 3.5, -2.5, -3.5])), // input + ]; + + let result = round_arrays(Arc::clone(&args[0]), None) + .expect("failed to initialize function round"); + let floats = + as_float32_array(&result).expect("failed to initialize function round"); + + let expected = Float32Array::from(vec![2.0, 4.0, -2.0, -4.0]); + + assert_eq!(floats, &expected); + } + #[test] fn test_round_f64_one_input() { let args: Vec = vec![ @@ -825,6 +862,22 @@ mod test { assert_eq!(floats, &expected); } + #[test] + fn test_round_even_f64_one_input() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![2.5, 3.5, -2.5, -3.5])), // input + ]; + + let result = round_arrays(Arc::clone(&args[0]), None) + .expect("failed to initialize function round"); + let floats = + as_float64_array(&result).expect("failed to initialize function round"); + + let expected = Float64Array::from(vec![2.0, 4.0, -2.0, -4.0]); + + assert_eq!(floats, &expected); + } + #[test] fn test_round_f32_cast_fail() { let args: Vec = vec![ diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 89ae30e3c047b..a4808cf5f95db 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -918,8 +918,8 @@ select round(a), round(b), round(c) from small_floats; ---- -1 0 -1 -1 NULL NULL +0 0 0 0 0 1 -1 0 0 # round with too large # max Int32 is 2147483647