regexp_count: stop short-circuiting the empty pattern to a count of 0.

What this patch changes:
  * regexp_count_inner: when the regex argument is treated as a scalar, slot 0
    is now checked with `is_null(0)`; a NULL regex becomes `None` instead of
    being wrapped unconditionally in `Some(regex_array.value(0))`.
  * Every `None | Some("")` early-return arm is narrowed to `None`, so an
    empty pattern now falls through to the `Some(regex)` arm like any other
    pattern. A NULL regex still yields 0.
  * Tests: a new scalar unit test (Utf8, LargeUtf8 and Utf8View inputs) expects
    ("abc", start 1) -> 4, ("abc", start 4) -> 1 and ("", start 1) -> 0 for
    the empty pattern; the sqllogictest file expects
    `regexp_count('abc', '')` to return 4.

NOTE(review): the leading whitespace of the Rust lines below was rebuilt from
rustfmt conventions; confirm against the target tree (or apply with
`git apply --ignore-whitespace`) if a hunk is rejected.

diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs
index 7123c4d5e60d5..d03e83ae67e6d 100644
--- a/datafusion/functions/src/regex/regexpcount.rs
+++ b/datafusion/functions/src/regex/regexpcount.rs
@@ -268,7 +268,14 @@ where
     S: StringArrayType<'a>,
 {
     let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 {
-        (Some(regex_array.value(0)), true)
+        (
+            if regex_array.is_null(0) {
+                None
+            } else {
+                Some(regex_array.value(0))
+            },
+            true,
+        )
     } else {
         (None, false)
     };
@@ -300,7 +307,7 @@ where
     match (is_regex_scalar, is_start_scalar, is_flags_scalar) {
         (true, true, true) => {
             let regex = match regex_scalar {
-                None | Some("") => {
+                None => {
                     return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
                 }
                 Some(regex) => regex,
@@ -317,7 +324,7 @@ where
         }
         (true, true, false) => {
             let regex = match regex_scalar {
-                None | Some("") => {
+                None => {
                     return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
                 }
                 Some(regex) => regex,
@@ -346,7 +353,7 @@ where
         }
         (true, false, true) => {
             let regex = match regex_scalar {
-                None | Some("") => {
+                None => {
                     return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
                 }
                 Some(regex) => regex,
@@ -366,7 +373,7 @@ where
         }
         (true, false, false) => {
             let regex = match regex_scalar {
-                None | Some("") => {
+                None => {
                     return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
                 }
                 Some(regex) => regex,
@@ -411,7 +418,7 @@ where
                     .zip(regex_array.iter())
                     .map(|(value, regex)| {
                         let regex = match regex {
-                            None | Some("") => return Ok(0),
+                            None => return Ok(0),
                             Some(regex) => regex,
                         };
 
@@ -447,7 +454,7 @@ where
                 izip!(values.iter(), regex_array.iter(), flags_array.iter())
                     .map(|(value, regex, flags)| {
                         let regex = match regex {
-                            None | Some("") => return Ok(0),
+                            None => return Ok(0),
                             Some(regex) => regex,
                         };
 
@@ -481,7 +488,7 @@ where
                 izip!(values.iter(), regex_array.iter(), start_array.iter())
                     .map(|(value, regex, start)| {
                         let regex = match regex {
-                            None | Some("") => return Ok(0),
+                            None => return Ok(0),
                             Some(regex) => regex,
                         };
 
@@ -531,7 +538,7 @@ where
                 )
                 .map(|(value, regex, start, flags)| {
                     let regex = match regex {
-                        None | Some("") => return Ok(0),
+                        None => return Ok(0),
                         Some(regex) => regex,
                     };
 
@@ -590,6 +597,7 @@ mod tests {
     fn test_regexp_count() {
         test_case_sensitive_regexp_count_scalar();
         test_case_sensitive_regexp_count_scalar_start();
+        test_case_sensitive_regexp_count_scalar_empty_pattern();
         test_case_insensitive_regexp_count_scalar_flags();
         test_case_sensitive_regexp_count_start_scalar_complex();
 
@@ -719,6 +727,61 @@ mod tests {
         });
     }
 
+    fn test_case_sensitive_regexp_count_scalar_empty_pattern() {
+        let values = ["abc", "abc", ""];
+        let regex = "";
+        let start = [1, 4, 1];
+        let expected: Vec<i64> = vec![4, 1, 0];
+
+        izip!(values.iter(), start.iter())
+            .enumerate()
+            .for_each(|(pos, (&v, &s))| {
+                let expected = expected.get(pos).cloned();
+
+                let v_sv = ScalarValue::Utf8(Some(v.to_string()));
+                let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
+                let start_sv = ScalarValue::Int64(Some(s));
+                let re =
+                    regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]);
+                match re {
+                    Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
+                        assert_eq!(
+                            v, expected,
+                            "regexp_count scalar empty-pattern test failed"
+                        );
+                    }
+                    _ => panic!("Unexpected result"),
+                }
+
+                let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
+                let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
+                let re =
+                    regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]);
+                match re {
+                    Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
+                        assert_eq!(
+                            v, expected,
+                            "regexp_count scalar empty-pattern test failed"
+                        );
+                    }
+                    _ => panic!("Unexpected result"),
+                }
+
+                let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
+                let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
+                let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv]);
+                match re {
+                    Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
+                        assert_eq!(
+                            v, expected,
+                            "regexp_count scalar empty-pattern test failed"
+                        );
+                    }
+                    _ => panic!("Unexpected result"),
+                }
+            });
+    }
+
     fn test_case_insensitive_regexp_count_scalar_flags() {
         let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
         let regex = "abc";
diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_count.slt b/datafusion/sqllogictest/test_files/regexp/regexp_count.slt
index d842a1ee81dfb..0344b58dd293f 100644
--- a/datafusion/sqllogictest/test_files/regexp/regexp_count.slt
+++ b/datafusion/sqllogictest/test_files/regexp/regexp_count.slt
@@ -26,6 +26,11 @@ SELECT regexp_count('123123123123123', '(12)3');
 ----
 5
 
+query I
+SELECT regexp_count('abc', '');
+----
+4
+
 query I
 SELECT regexp_count('123123123123', '123', 1);
 ----