diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 5c47d6b7c566..c6e27b2603d9 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -23,7 +23,7 @@ use datafusion_expr::{Expr, Filter, Operator}; use crate::optimizer::ApplyOrder; use datafusion_common::tree_node::Transformed; -use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; +use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use std::sync::Arc; /// @@ -298,6 +298,23 @@ fn extract_non_nullable_columns( right_schema, false, ), + // IN list and BETWEEN are null-rejecting on the input expression: + // if the input column is NULL, the result is NULL (filtered out), + // regardless of whether the list/range contains NULLs. + Expr::InList(InList { expr, .. }) => extract_non_nullable_columns( + expr, + non_nullable_cols, + left_schema, + right_schema, + false, + ), + Expr::Between(between) => extract_non_nullable_columns( + &between.expr, + non_nullable_cols, + left_schema, + right_schema, + false, + ), _ => {} } } @@ -309,6 +326,7 @@ mod tests { use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; use datafusion_expr::{ Operator::{And, Or}, binary_expr, cast, col, lit, @@ -436,6 +454,221 @@ mod tests { ") } + #[test] + fn eliminate_left_with_in_list() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // t2.b IN (1, 2, 3) rejects nulls — if t2.b is NULL the IN returns + // NULL which is filtered out. So Left Join should become Inner Join. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").in_list(vec![lit(1u32), lit(2u32), lit(3u32)], false))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IN ([UInt32(1), UInt32(2), UInt32(3)]) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_in_list_containing_null() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // IN list with NULL still rejects null input columns: + // if t2.b is NULL, NULL IN (1, NULL) evaluates to NULL, which is filtered out + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter( + col("t2.b") + .in_list(vec![lit(1u32), lit(ScalarValue::UInt32(None))], false), + )? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IN ([UInt32(1), UInt32(NULL)]) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_not_in_list() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // NOT IN also rejects nulls: if t2.b is NULL, NOT (NULL IN (...)) + // evaluates to NULL, which is filtered out + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").in_list(vec![lit(1u32), lit(2u32)], true))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b NOT IN ([UInt32(1), UInt32(2)]) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_between() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // BETWEEN rejects nulls: if t2.b is NULL, NULL BETWEEN 1 AND 10 + // evaluates to NULL, which is filtered out + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").between(lit(1u32), lit(10u32)))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b BETWEEN UInt32(1) AND UInt32(10) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_right_with_between() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Right join: filter on left (nullable) side with BETWEEN should convert to Inner + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Right, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t1.b").between(lit(1u32), lit(10u32)))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b BETWEEN UInt32(1) AND UInt32(10) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_full_with_between() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Full join with BETWEEN on both sides should become Inner + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t1.b").between(lit(1u32), lit(10u32)), + And, + col("t2.b").between(lit(5u32), lit(20u32)), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b BETWEEN UInt32(1) AND UInt32(10) AND t2.b BETWEEN UInt32(5) AND UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_full_with_in_list() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Full join with IN filters on both sides should become Inner + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t1.b").in_list(vec![lit(1u32), lit(2u32)], false), + And, + col("t2.b").in_list(vec![lit(3u32), lit(4u32)], false), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b IN ([UInt32(1), UInt32(2)]) AND t2.b IN ([UInt32(3), UInt32(4)]) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_in_list_or_is_null() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // WHERE (t2.b IN (1, 2)) OR (t2.b IS NULL) + // The OR with IS NULL makes the predicate null-tolerant: + // when t2.b is NULL, IS NULL returns true, so the whole OR is true. + // The outer join must be preserved. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t2.b").in_list(vec![lit(1u32), lit(2u32)], false), + Or, + col("t2.b").is_null(), + ))? + .build()?; + + // Should NOT be converted to Inner — OR with IS NULL preserves null rows + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IN ([UInt32(1), UInt32(2)]) OR t2.b IS NULL + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + #[test] fn eliminate_full_with_type_cast() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; diff --git a/datafusion/sqllogictest/test_files/eliminate_outer_join.slt b/datafusion/sqllogictest/test_files/eliminate_outer_join.slt new file mode 100644 index 000000000000..a44d9af77cfb --- /dev/null +++ b/datafusion/sqllogictest/test_files/eliminate_outer_join.slt @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test EliminateOuterJoin rule: outer joins with null-rejecting +# predicates (IN list, BETWEEN) should be converted to inner joins. + +statement ok +create table t1(a int, b int, c varchar); + +statement ok +create table t2(x int, y int, z varchar); + +statement ok +insert into t1 values (1, 10, 'a'), (2, 20, 'b'), (3, 30, 'c'), (null, 40, 'd'); + +statement ok +insert into t2 values (1, 100, 'p'), (2, 200, 'q'), (null, 300, 'r'); + +statement ok +set datafusion.explain.logical_plan_only = true; + +### +### IN list tests +### + +# LEFT JOIN + WHERE t2.x IN (...) -> INNER JOIN +# IN is null-rejecting: if t2.x is NULL, IN returns NULL (filtered out). +# After conversion to INNER, PushDownFilter infers the predicate to both sides. +query TT +explain select * from t1 left join t2 on t1.a = t2.x where t2.x in (1, 2, 3); +---- +logical_plan +01)Inner Join: t1.a = t2.x +02)--Filter: t1.a = Int32(1) OR t1.a = Int32(2) OR t1.a = Int32(3) +03)----TableScan: t1 projection=[a, b, c] +04)--Filter: t2.x = Int32(1) OR t2.x = Int32(2) OR t2.x = Int32(3) +05)----TableScan: t2 projection=[x, y, z] + +# Verify result correctness +query IITIIT rowsort +select * from t1 left join t2 on t1.a = t2.x where t2.x in (1, 2, 3); +---- +1 10 a 1 100 p +2 20 b 2 200 q + +# RIGHT JOIN + WHERE t1.a IN (...) -> INNER JOIN +query TT +explain select * from t1 right join t2 on t1.a = t2.x where t1.a in (1, 2); +---- +logical_plan +01)Inner Join: t1.a = t2.x +02)--Filter: t1.a = Int32(1) OR t1.a = Int32(2) +03)----TableScan: t1 projection=[a, b, c] +04)--Filter: t2.x = Int32(1) OR t2.x = Int32(2) +05)----TableScan: t2 projection=[x, y, z] + +query IITIIT rowsort +select * from t1 right join t2 on t1.a = t2.x where t1.a in (1, 2); +---- +1 10 a 1 100 p +2 20 b 2 200 q + +# FULL JOIN + WHERE on both sides -> INNER JOIN +query TT +explain select * from t1 full join t2 on t1.a = t2.x where t1.a in (1, 2) and t2.x in (1, 2); +---- +logical_plan +01)Inner Join: t1.a = t2.x +02)--Filter: t1.a = Int32(1) OR t1.a = Int32(2) +03)----TableScan: t1 projection=[a, b, c] +04)--Filter: t2.x = Int32(1) OR t2.x = Int32(2) +05)----TableScan: t2 projection=[x, y, z] + +query IITIIT rowsort +select * from t1 full join t2 on t1.a = t2.x where t1.a in (1, 2) and t2.x in (1, 2); +---- +1 10 a 1 100 p +2 20 b 2 200 q + +# IN list with NULL in the list — still null-rejecting on input column +query TT +explain select * from t1 left join t2 on t1.a = t2.x where t2.x in (1, 2, null); +---- +logical_plan +01)Inner Join: t1.a = t2.x +02)--Filter: t1.a = Int32(1) OR t1.a = Int32(2) OR Boolean(NULL) +03)----TableScan: t1 projection=[a, b, c] +04)--Filter: t2.x = Int32(1) OR t2.x = Int32(2) OR Boolean(NULL) +05)----TableScan: t2 projection=[x, y, z] + +query IITIIT rowsort +select * from t1 left join t2 on t1.a = t2.x where t2.x in (1, 2, null); +---- +1 10 a 1 100 p +2 20 b 2 200 q + +# NOT IN — also null-rejecting +query TT +explain select * from t1 left join t2 on t1.a = t2.x where t2.x not in (99); +---- +logical_plan +01)Inner Join: t1.a = t2.x +02)--Filter: t1.a != Int32(99) +03)----TableScan: t1 projection=[a, b, c] +04)--Filter: t2.x != Int32(99) +05)----TableScan: t2 projection=[x, y, z] + +### +### BETWEEN tests +### + +# LEFT JOIN + WHERE t2.x BETWEEN ... -> INNER JOIN +query TT +explain select * from t1 left join t2 on t1.a = t2.x where t2.x between 1 and 3; +---- +logical_plan +01)Inner Join: t1.a = t2.x +02)--Filter: t1.a >= Int32(1) AND t1.a <= Int32(3) +03)----TableScan: t1 projection=[a, b, c] +04)--Filter: t2.x >= Int32(1) AND t2.x <= Int32(3) +05)----TableScan: t2 projection=[x, y, z] + +query IITIIT rowsort +select * from t1 left join t2 on t1.a = t2.x where t2.x between 1 and 3; +---- +1 10 a 1 100 p +2 20 b 2 200 q + +# RIGHT JOIN + WHERE t1.a BETWEEN ... -> INNER JOIN +query TT +explain select * from t1 right join t2 on t1.a = t2.x where t1.a between 1 and 2; +---- +logical_plan +01)Inner Join: t1.a = t2.x +02)--Filter: t1.a >= Int32(1) AND t1.a <= Int32(2) +03)----TableScan: t1 projection=[a, b, c] +04)--Filter: t2.x >= Int32(1) AND t2.x <= Int32(2) +05)----TableScan: t2 projection=[x, y, z] + +query IITIIT rowsort +select * from t1 right join t2 on t1.a = t2.x where t1.a between 1 and 2; +---- +1 10 a 1 100 p +2 20 b 2 200 q + +### +### Negative tests: join should NOT be converted +### + +# IN on preserved side of LEFT JOIN — doesn't help eliminate. +# Filter is still pushed down to the right side via join key inference. +query TT +explain select * from t1 left join t2 on t1.a = t2.x where t1.a in (1, 2, 3); +---- +logical_plan +01)Left Join: t1.a = t2.x +02)--Filter: t1.a = Int32(1) OR t1.a = Int32(2) OR t1.a = Int32(3) +03)----TableScan: t1 projection=[a, b, c] +04)--Filter: t2.x = Int32(1) OR t2.x = Int32(2) OR t2.x = Int32(3) +05)----TableScan: t2 projection=[x, y, z] + +query IITIIT rowsort +select * from t1 left join t2 on t1.a = t2.x where t1.a in (1, 2, 3); +---- +1 10 a 1 100 p +2 20 b 2 200 q +3 30 c NULL NULL NULL + +# IN(...) OR IS NULL — not null-rejecting, join preserved +query TT +explain select * from t1 left join t2 on t1.a = t2.x where t2.x in (1, 2) or t2.x is null; +---- +logical_plan +01)Filter: t2.x = Int32(1) OR t2.x = Int32(2) OR t2.x IS NULL +02)--Left Join: t1.a = t2.x +03)----TableScan: t1 projection=[a, b, c] +04)----TableScan: t2 projection=[x, y, z] + +query IITIIT rowsort +select * from t1 left join t2 on t1.a = t2.x where t2.x in (1, 2) or t2.x is null; +---- +1 10 a 1 100 p +2 20 b 2 200 q +3 30 c NULL NULL NULL +NULL 40 d NULL NULL NULL + +### +### Cleanup +### + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +drop table t1; + +statement ok +drop table t2;