Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 85 additions & 19 deletions datafusion/functions/src/regex/regexpcount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,10 @@ where
S: StringArrayType<'a>,
{
let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 {
(Some(regex_array.value(0)), true)
(
(!regex_array.is_null(0)).then(|| regex_array.value(0)),
true,
)
} else {
(None, false)
};
Expand Down Expand Up @@ -300,7 +303,7 @@ where
match (is_regex_scalar, is_start_scalar, is_flags_scalar) {
(true, true, true) => {
let regex = match regex_scalar {
None | Some("") => {
None => {
return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
}
Some(regex) => regex,
Expand All @@ -317,7 +320,7 @@ where
}
(true, true, false) => {
let regex = match regex_scalar {
None | Some("") => {
None => {
return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
}
Some(regex) => regex,
Expand Down Expand Up @@ -346,7 +349,7 @@ where
}
(true, false, true) => {
let regex = match regex_scalar {
None | Some("") => {
None => {
return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
}
Some(regex) => regex,
Expand All @@ -366,7 +369,7 @@ where
}
(true, false, false) => {
let regex = match regex_scalar {
None | Some("") => {
None => {
return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
}
Some(regex) => regex,
Expand Down Expand Up @@ -411,7 +414,7 @@ where
.zip(regex_array.iter())
.map(|(value, regex)| {
let regex = match regex {
None | Some("") => return Ok(0),
None => return Ok(0),
Some(regex) => regex,
};

Expand Down Expand Up @@ -447,7 +450,7 @@ where
izip!(values.iter(), regex_array.iter(), flags_array.iter())
.map(|(value, regex, flags)| {
let regex = match regex {
None | Some("") => return Ok(0),
None => return Ok(0),
Some(regex) => regex,
};

Expand Down Expand Up @@ -481,7 +484,7 @@ where
izip!(values.iter(), regex_array.iter(), start_array.iter())
.map(|(value, regex, start)| {
let regex = match regex {
None | Some("") => return Ok(0),
None => return Ok(0),
Some(regex) => regex,
};

Expand Down Expand Up @@ -531,7 +534,7 @@ where
)
.map(|(value, regex, start, flags)| {
let regex = match regex {
None | Some("") => return Ok(0),
None => return Ok(0),
Some(regex) => regex,
};

Expand All @@ -551,7 +554,7 @@ fn count_matches(
start: Option<i64>,
) -> Result<i64, ArrowError> {
let value = match value {
None | Some("") => return Ok(0),
None => return Ok(0),
Some(value) => value,
};

Expand All @@ -562,12 +565,23 @@ fn count_matches(
));
}

let char_len = value.chars().count();
let start_index = (start as usize).saturating_sub(1);

if start_index > char_len {
return Ok(0);
}

// Find the byte offset for the start position (1-based character index)
let byte_offset = value
.char_indices()
.nth((start as usize).saturating_sub(1))
.map(|(idx, _)| idx)
.unwrap_or(value.len());
let byte_offset = if start_index == char_len {
value.len()
} else {
value
.char_indices()
.nth(start_index)
.map(|(idx, _)| idx)
.unwrap_or(value.len())
};

// Use string slicing instead of collecting chars into a new String
let find_slice = &value[byte_offset..];
Expand All @@ -589,6 +603,7 @@ mod tests {
#[test]
fn test_regexp_count() {
test_case_sensitive_regexp_count_scalar();
test_case_sensitive_regexp_count_empty_pattern_scalar();
test_case_sensitive_regexp_count_scalar_start();
test_case_insensitive_regexp_count_scalar_flags();
test_case_sensitive_regexp_count_start_scalar_complex();
Expand Down Expand Up @@ -675,6 +690,57 @@ mod tests {
});
}

fn test_case_sensitive_regexp_count_empty_pattern_scalar() {
let values = ["", "abc", "abc"];
let start_positions = [1, 1, 2];
let expected: Vec<i64> = vec![1, 4, 3];

values
.iter()
.zip(start_positions.iter())
.enumerate()
.for_each(|(pos, (&value, &start))| {
let expected = expected.get(pos).cloned();
let start_sv = ScalarValue::Int64(Some(start));

let re = regexp_count_with_scalar_values(&[
ScalarValue::Utf8(Some(value.to_string())),
ScalarValue::Utf8(Some("".to_string())),
start_sv.clone(),
]);
match re {
Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
assert_eq!(v, expected, "regexp_count scalar test failed");
}
_ => panic!("Unexpected result"),
}

let re = regexp_count_with_scalar_values(&[
ScalarValue::LargeUtf8(Some(value.to_string())),
ScalarValue::LargeUtf8(Some("".to_string())),
start_sv.clone(),
]);
match re {
Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
assert_eq!(v, expected, "regexp_count scalar test failed");
}
_ => panic!("Unexpected result"),
}

let re = regexp_count_with_scalar_values(&[
ScalarValue::Utf8View(Some(value.to_string())),
ScalarValue::Utf8View(Some("".to_string())),
start_sv,
]);
match re {
Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
assert_eq!(v, expected, "regexp_count scalar test failed");
}
_ => panic!("Unexpected result"),
}
});
}

fn test_case_sensitive_regexp_count_scalar_start() {
let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
let regex = "abc";
Expand Down Expand Up @@ -792,7 +858,7 @@ mod tests {
let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]);
let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);

let expected = Int64Array::from(vec![0, 1, 2, 2, 2]);
let expected = Int64Array::from(vec![1, 1, 2, 2, 2]);

let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
assert_eq!(re.as_ref(), &expected);
Expand All @@ -806,7 +872,7 @@ mod tests {
let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
let start = Int64Array::from(vec![1, 2, 3, 4, 5]);

let expected = Int64Array::from(vec![0, 0, 1, 1, 0]);
let expected = Int64Array::from(vec![1, 0, 1, 1, 0]);

let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
.unwrap();
Expand All @@ -822,7 +888,7 @@ mod tests {
let start = Int64Array::from(vec![1]);
let flags = A::from(vec!["", "i", "", "", "i"]);

let expected = Int64Array::from(vec![0, 1, 2, 2, 3]);
let expected = Int64Array::from(vec![1, 1, 2, 2, 3]);

let re = regexp_count_func(&[
Arc::new(values),
Expand Down Expand Up @@ -910,7 +976,7 @@ mod tests {
let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
let flags = A::from(vec!["", "i", "", "", "i"]);

let expected = Int64Array::from(vec![0, 1, 1, 1, 1]);
let expected = Int64Array::from(vec![1, 1, 1, 1, 1]);

let re = regexp_count_func(&[
Arc::new(values),
Expand Down
25 changes: 25 additions & 0 deletions datafusion/sqllogictest/test_files/regexp/regexp_count.slt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,31 @@ SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i');
----
4

query I
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice addition to cover regexp_count('abc', ''). Since this change also affects the start-position and flag handling paths, it would be great to add a couple more SQL-visible cases here as well, like regexp_count('abc', '', 2) and regexp_count('abc', '', 1, 'i'). It may also be worth covering the one-past-end or beyond-end boundary behavior if that is intentional. The Rust unit tests already exercise some of the internals, but adding these to the SLT suite would help lock in the user-facing behavior that motivated this PR.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review! I have added some tests in this slt file. Please take a look again.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can refer to the execution results of PGSQL.

SELECT regexp_count('abc', '');
----
4

query I
SELECT regexp_count('abc', '', 2);
----
3

query I
SELECT regexp_count('abc', '', 1, 'i');
----
4

query I
SELECT regexp_count('abc', '', 4);
----
1

query I
SELECT regexp_count('abc', '', 5);
----
0

statement error
External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based
SELECT regexp_count('123123123123', '123', 0);
Expand Down
Loading