diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index 1ed805513b..c3860dcdb8 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -60,9 +60,6 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on ### Array Expressions -- **ArraysOverlap**: Inconsistent behavior when arrays contain NULL values. - [#3645](https://github.com/apache/datafusion-comet/issues/3645), - [#2036](https://github.com/apache/datafusion-comet/issues/2036) - **ArrayUnion**: Sorts input arrays before performing the union, while Spark preserves the order of the first array and appends unique elements from the second. [#3644](https://github.com/apache/datafusion-comet/issues/3644) diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 16c0823475..2194a0b6f4 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -246,7 +246,7 @@ Comet supports using the following aggregate functions within window contexts wi | ArrayRemove | Yes | | | ArrayRepeat | No | | | ArrayUnion | No | Behaves differently than spark. Comet sorts the input arrays before performing the union, while Spark preserves the order of the first array and appends unique elements from the second. | -| ArraysOverlap | No | | +| ArraysOverlap | Yes | | | CreateArray | Yes | | | ElementAt | Yes | Input must be an array. Map inputs are not supported. | | Flatten | Yes | | diff --git a/native/spark-expr/src/array_funcs/arrays_overlap.rs b/native/spark-expr/src/array_funcs/arrays_overlap.rs new file mode 100644 index 0000000000..094e25c365 --- /dev/null +++ b/native/spark-expr/src/array_funcs/arrays_overlap.rs @@ -0,0 +1,374 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Spark-compatible `arrays_overlap` with correct SQL three-valued null logic. +//! +//! DataFusion's `array_has_any` uses `RowConverter` for element comparison, which +//! treats NULL == NULL as true (grouping semantics). Spark's `arrays_overlap` uses +//! SQL equality where NULL == NULL is unknown (null). This implementation correctly +//! returns: +//! - true if any non-null element appears in both arrays +//! - null if no definite overlap but either array contains null elements +//! 
- false if no overlap and neither array contains null elements
+
+use arrow::array::{Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar};
+use arrow::compute::kernels::cmp::eq;
+use arrow::datatypes::{DataType, FieldRef};
+use datafusion::common::{exec_err, utils::take_function_args, Result, ScalarValue};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkArraysOverlap {
+    signature: Signature,
+}
+
+impl Default for SparkArraysOverlap {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkArraysOverlap {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkArraysOverlap {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "spark_arrays_overlap"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn return_field_from_args(
+        &self,
+        _args: datafusion::logical_expr::ReturnFieldArgs,
+    ) -> Result<FieldRef> {
+        Ok(Arc::new(arrow::datatypes::Field::new(
+            self.name(),
+            DataType::Boolean,
+            true,
+        )))
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let [left, right] = take_function_args(self.name(), &args.args)?;
+
+        // Return null if either input is a null scalar
+        if let ColumnarValue::Scalar(s) = &left {
+            if s.is_null() {
+                return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
+            }
+        }
+        if let ColumnarValue::Scalar(s) = &right {
+            if s.is_null() {
+                return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
+            }
+        }
+
+        match (left, right) {
+            (ColumnarValue::Array(left_arr), ColumnarValue::Array(right_arr)) => {
+                let result = match (left_arr.data_type(), right_arr.data_type()) {
+                    (DataType::List(_), DataType::List(_)) => arrays_overlap_list::<i32>(
+                        left_arr.as_any().downcast_ref().unwrap(),
+                        right_arr.as_any().downcast_ref().unwrap(),
+                    )?,
+                    (DataType::LargeList(_), DataType::LargeList(_)) => arrays_overlap_list::<i64>(
+                        left_arr.as_any().downcast_ref().unwrap(),
+                        right_arr.as_any().downcast_ref().unwrap(),
+                    )?,
+                    (l, r) => {
+                        return exec_err!(
+                            "spark_arrays_overlap does not support types '{l}' and '{r}'"
+                        )
+                    }
+                };
+                Ok(ColumnarValue::Array(result))
+            }
+            (left, right) => {
+                // Handle scalar inputs by converting to arrays
+                let left_arr = left.to_array(1)?;
+                let right_arr = right.to_array(1)?;
+                let result = match (left_arr.data_type(), right_arr.data_type()) {
+                    (DataType::List(_), DataType::List(_)) => arrays_overlap_list::<i32>(
+                        left_arr.as_any().downcast_ref().unwrap(),
+                        right_arr.as_any().downcast_ref().unwrap(),
+                    )?,
+                    (DataType::LargeList(_), DataType::LargeList(_)) => arrays_overlap_list::<i64>(
+                        left_arr.as_any().downcast_ref().unwrap(),
+                        right_arr.as_any().downcast_ref().unwrap(),
+                    )?,
+                    (l, r) => {
+                        return exec_err!(
+                            "spark_arrays_overlap does not support types '{l}' and '{r}'"
+                        )
+                    }
+                };
+                let scalar = ScalarValue::try_from_array(&result, 0)?;
+                Ok(ColumnarValue::Scalar(scalar))
+            }
+        }
+    }
+}
+
+/// Spark-compatible arrays_overlap with SQL three-valued null logic.
+///
+/// For each row, compares elements of two list arrays and returns:
+/// - null if either array is null
+/// - true if any non-null element appears in both arrays
+/// - null if no definite overlap but either array contains null elements
+/// - false otherwise
+fn arrays_overlap_list<O: OffsetSizeTrait>(
+    left: &GenericListArray<O>,
+    right: &GenericListArray<O>,
+) -> Result<ArrayRef> {
+    let len = left.len();
+    let mut builder = BooleanArray::builder(len);
+
+    for i in 0..len {
+        if left.is_null(i) || right.is_null(i) {
+            builder.append_null();
+            continue;
+        }
+
+        let left_values = left.value(i);
+        let right_values = right.value(i);
+
+        // Empty array cannot overlap
+        if left_values.is_empty() || right_values.is_empty() {
+            builder.append_value(false);
+            continue;
+        }
+
+        // DataFusion's make_array(NULL) produces a List with NullArray values.
+        // NullArray means all elements are null by definition.
+        if left_values.data_type() == &DataType::Null || right_values.data_type() == &DataType::Null
+        {
+            builder.append_null();
+            continue;
+        }
+
+        let mut found_overlap = false;
+
+        // Ensure smaller array is on the probe side for fewer broadcast eq calls
+        let (probe, search) = if left_values.len() <= right_values.len() {
+            (&left_values, &right_values)
+        } else {
+            (&right_values, &left_values)
+        };
+
+        // For each non-null element in probe, broadcast eq against the full search array.
+        // One kernel call per probe element: O(p * s) total work but with vectorized inner loop.
+        for pi in 0..probe.len() {
+            if probe.is_null(pi) {
+                continue;
+            }
+            let scalar = Scalar::new(probe.slice(pi, 1));
+            let eq_result = eq(search, &scalar)?;
+            if eq_result.true_count() > 0 {
+                found_overlap = true;
+                break;
+            }
+        }
+
+        if found_overlap {
+            builder.append_value(true);
+        } else if left_values.null_count() > 0 || right_values.null_count() > 0 {
+            builder.append_null();
+        } else {
+            builder.append_value(false);
+        }
+    }
+
+    Ok(Arc::new(builder.finish()))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Int32Array, ListArray};
+    use arrow::buffer::{NullBuffer, OffsetBuffer};
+    use arrow::datatypes::Field;
+
+    fn make_list_array(
+        values: &Int32Array,
+        offsets: &[i32],
+        nulls: Option<NullBuffer>,
+    ) -> ListArray {
+        ListArray::new(
+            Arc::new(Field::new("item", DataType::Int32, true)),
+            OffsetBuffer::new(offsets.to_vec().into()),
+            Arc::new(values.clone()),
+            nulls,
+        )
+    }
+
+    #[test]
+    fn test_basic_overlap() -> Result<()> {
+        // [1, 2, 3] vs [3, 4, 5] => true
+        let left = make_list_array(&Int32Array::from(vec![1, 2, 3]), &[0, 3], None);
+        let right = make_list_array(&Int32Array::from(vec![3, 4, 5]), &[0, 3], None);
+
+        let result = arrays_overlap_list::<i32>(&left, &right)?;
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert!(result.value(0));
+        assert!(result.is_valid(0));
+        Ok(())
+    }
+
+    #[test]
+    fn test_no_overlap() -> Result<()> {
+        // [1, 2] vs [3, 4] => false
+        let left = make_list_array(&Int32Array::from(vec![1, 2]), &[0, 2], None);
+        let right = make_list_array(&Int32Array::from(vec![3, 4]), &[0, 2], None);
+
+        let result = arrays_overlap_list::<i32>(&left, &right)?;
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert!(!result.value(0));
+        assert!(result.is_valid(0));
+        Ok(())
+    }
+
+    #[test]
+    fn test_null_only_overlap() -> Result<()> {
+        // [1, NULL] vs [NULL, 2] => null (no definite overlap, but nulls present)
+        let left = make_list_array(&Int32Array::from(vec![Some(1), None]), &[0, 2], None);
+        let right = make_list_array(&Int32Array::from(vec![None, Some(2)]), &[0, 2], None);
+
+        let result = arrays_overlap_list::<i32>(&left, &right)?;
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert!(result.is_null(0));
+        Ok(())
+    }
+
+    #[test]
+    fn test_null_with_overlap() -> Result<()> {
+        // [1, NULL] vs [1, 2] => true (definite overlap on 1)
+        let left = make_list_array(&Int32Array::from(vec![Some(1), None]), &[0, 2], None);
+        let right = make_list_array(&Int32Array::from(vec![1, 2]), &[0, 2], None);
+
+        let result = arrays_overlap_list::<i32>(&left, &right)?;
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert!(result.value(0));
+        assert!(result.is_valid(0));
+        Ok(())
+    }
+
+    #[test]
+    fn test_empty_array() -> Result<()> {
+        // [1, NULL, 3] vs [] => false
+        let left = make_list_array(
+            &Int32Array::from(vec![Some(1), None, Some(3)]),
+            &[0, 3],
+            None,
+        );
+        let right = make_list_array(&Int32Array::from(Vec::<i32>::new()), &[0, 0], None);
+
+        let result = arrays_overlap_list::<i32>(&left, &right)?;
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert!(!result.value(0));
+        assert!(result.is_valid(0));
+        Ok(())
+    }
+
+    #[test]
+    fn test_null_array() -> Result<()> {
+        // NULL vs [1, 2] => null
+        let left = make_list_array(
+            &Int32Array::from(Vec::<i32>::new()),
+            &[0, 0],
+            Some(NullBuffer::from(vec![false])),
+        );
+        let right = make_list_array(&Int32Array::from(vec![1, 2]), &[0, 2], None);
+
+        let result = arrays_overlap_list::<i32>(&left, &right)?;
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert!(result.is_null(0));
+        Ok(())
+    }
+
+    #[test]
+    fn test_both_null_elements() -> Result<()> {
+        // [NULL] vs [NULL] => null
+        let left = make_list_array(&Int32Array::from(vec![None::<i32>]), &[0, 1], None);
+        let right = make_list_array(&Int32Array::from(vec![None::<i32>]), &[0, 1], None);
+
+        let result = arrays_overlap_list::<i32>(&left, &right)?;
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert!(result.is_null(0));
+        Ok(())
+    }
+
+    #[test]
+    fn test_both_null_elements_via_null_array() -> Result<()> {
+        // Simulate what DataFusion's make_array(NULL) produces: List with NullArray values
+        use arrow::array::NullArray;
+
+        let null_values = Arc::new(NullArray::new(1)) as ArrayRef;
+        let null_field = Arc::new(Field::new("item", DataType::Null, true));
+        let left = ListArray::new(
+            Arc::clone(&null_field),
+            OffsetBuffer::new(vec![0, 1].into()),
+            Arc::clone(&null_values),
+            None,
+        );
+        let right = ListArray::new(
+            null_field,
+            OffsetBuffer::new(vec![0, 1].into()),
+            null_values,
+            None,
+        );
+
+        let result = arrays_overlap_list::<i32>(&left, &right)?;
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert!(
+            result.is_null(0),
+            "Expected null for [NULL] vs [NULL] (NullArray representation), got {:?}",
+            result
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_one_null_element_no_overlap() -> Result<()> {
+        // [3, NULL] vs [1, 2] => null
+        let left = make_list_array(&Int32Array::from(vec![Some(3), None]), &[0, 2], None);
+        let right = make_list_array(&Int32Array::from(vec![1, 2]), &[0, 2], None);
+
+        let result = arrays_overlap_list::<i32>(&left, &right)?;
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert!(result.is_null(0));
+        Ok(())
+    }
+}
diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs
index 2bd1b9631b..55e0d4a22e 100644
--- a/native/spark-expr/src/array_funcs/mod.rs
+++ b/native/spark-expr/src/array_funcs/mod.rs
@@ -17,12 +17,14 @@
 mod array_compact;
 mod array_insert;
+mod arrays_overlap;
 mod get_array_struct_fields;
 mod list_extract;
 mod size;
 
 pub use array_compact::SparkArrayCompact;
 pub use array_insert::ArrayInsert;
+pub use arrays_overlap::SparkArraysOverlap;
 pub use get_array_struct_fields::GetArrayStructFields;
 pub use list_extract::ListExtract;
 pub use size::{spark_size, SparkSizeFunc};
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 9c91bb69c9..8911dda1cc 100644
--- 
a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -23,8 +23,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
     spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
-    spark_unscaled_value, EvalMode, SparkArrayCompact, SparkContains, SparkDateDiff,
-    SparkDateTrunc, SparkMakeDate, SparkSizeFunc,
+    spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArraysOverlap, SparkContains,
+    SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -197,6 +197,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
 fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
     vec![
         Arc::new(ScalarUDF::new_from_impl(SparkArrayCompact::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkArraysOverlap::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index f107d5b309..7cc818a40c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -179,23 +179,15 @@ object CometArrayMin extends CometExpressionSerde[ArrayMin] {
 }
 
 object CometArraysOverlap extends CometExpressionSerde[ArraysOverlap] {
-
-  override def getSupportLevel(expr: ArraysOverlap): SupportLevel =
-    Incompatible(
-      Some(
-        "Inconsistent behavior with NULL values" +
-          " (https://github.com/apache/datafusion-comet/issues/3645)" +
-          " (https://github.com/apache/datafusion-comet/issues/2036)"))
-
   override def convert(
       expr: ArraysOverlap,
       inputs: Seq[Attribute],
      binding: Boolean): 
Option[ExprOuterClass.Expr] = {
-    val leftArrayExprProto = exprToProto(expr.children.head, inputs, binding)
-    val rightArrayExprProto = exprToProto(expr.children(1), inputs, binding)
+    val leftArrayExprProto = exprToProto(expr.left, inputs, binding)
+    val rightArrayExprProto = exprToProto(expr.right, inputs, binding)
 
     val arraysOverlapScalarExpr = scalarFunctionExprToProtoWithReturnType(
-      "array_has_any",
+      "spark_arrays_overlap",
       BooleanType,
       false,
       leftArrayExprProto,
diff --git a/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql b/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql
index 27d28a7402..88cf38d211 100644
--- a/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql
+++ b/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql
@@ -15,7 +15,6 @@
 -- specific language governing permissions and limitations
 -- under the License.
 
--- Config: spark.comet.expression.ArraysOverlap.allowIncompatible=true
 -- ConfigMatrix: parquet.enable.dictionary=false,true
 
 statement
@@ -24,11 +23,11 @@ CREATE TABLE test_arrays_overlap(a array<int>, b array<int>) USING parquet
 
 statement
 INSERT INTO test_arrays_overlap VALUES (array(1, 2, 3), array(3, 4, 5)), (array(1, 2), array(3, 4)), (array(), array(1)), (NULL, array(1)), (array(1, NULL), array(NULL, 2))
 
-query ignore(https://github.com/apache/datafusion-comet/issues/3645)
+query
 SELECT arrays_overlap(a, b) FROM test_arrays_overlap
 
 -- column + literal
-query ignore(https://github.com/apache/datafusion-comet/issues/3645)
+query
 SELECT arrays_overlap(a, array(3, 4, 5)) FROM test_arrays_overlap
 
 -- literal + column
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index bb519492db..005df7ec56 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -23,7 +23,7 
@@ import scala.util.Random
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArraysOverlap, ArrayUnion}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArrayUnion}
 import org.apache.spark.sql.catalyst.expressions.{ArrayContains, ArrayRemove}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
@@ -545,22 +545,74 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
 }
 
   test("arrays_overlap") {
-    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[ArraysOverlap]) -> "true") {
-      Seq(true, false).foreach { dictionaryEnabled =>
-        withTempDir { dir =>
-          withTempView("t1") {
-            val path = new Path(dir.toURI.toString, "test.parquet")
-            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
-            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-            checkSparkAnswerAndOperator(sql(
-              "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null"))
-            checkSparkAnswerAndOperator(sql(
-              "SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null"))
-            checkSparkAnswerAndOperator(sql(
-              "SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null"))
-            checkSparkAnswerAndOperator(spark.sql(
-              "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1"));
-          }
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        withTempView("t1") {
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+          checkSparkAnswerAndOperator(sql(
+            "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null"))
+          checkSparkAnswerAndOperator(sql(
+            "SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null"))
+          checkSparkAnswerAndOperator(sql(
+            "SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null"))
+          checkSparkAnswerAndOperator(spark.sql(
+            "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1"));
+        }
+      }
+    }
+  }
+
+  test("arrays_overlap - null handling behavior verification") {
+    withSQLConf(
+      "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+      withTable("t") {
+        sql("create table t using parquet as select CAST(NULL as array<int>) a1 from range(1)")
+        val data = Seq(
+          "array(1, 2, 3)",
+          "array(3, 4, 5)",
+          "array(1, 2)",
+          "array(3, 4)",
+          "array(1, NULL, 3)",
+          "array(4, 5)",
+          "array(1, 4)",
+          "array(1, NULL)",
+          "array(2, NULL)",
+          "array(NULL, 2)",
+          "array(1)",
+          "array(2)",
+          "array()",
+          "array(NULL)",
+          "array(NULL, NULL)",
+          "a1")
+        for (y <- data; x <- data) {
+          checkSparkAnswerAndOperator(sql(s"SELECT arrays_overlap($y, $x) from t"))
+        }
+      }
+    }
+  }
+
+  test("arrays_overlap - nested array null handling behavior verification") {
+    withSQLConf(
+      "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+      withTable("t") {
+        sql(
+          "create table t using parquet as select CAST(NULL as array<array<int>>) a1 from range(1)")
+        val data = Seq(
+          "array(array(1, 2), array(3, 4))",
+          "array(array(1, 2), array(5, 6))",
+          "array(array(1, 2))",
+          "array(array(3, 4))",
+          "array(array(1, NULL))",
+          "array(array(NULL, 2))",
+          "array(array(NULL))",
+          "array(CAST(NULL as array<int>))",
+          "array(array(1, 2), CAST(NULL as array<int>))",
+          "array()",
+          "a1")
+        for (y <- data; x <- data) {
+          checkSparkAnswerAndOperator(sql(s"SELECT arrays_overlap($y, $x) from t1"))
+        }
+      }
+    }