Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 48 additions & 50 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3429,31 +3429,30 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
assert_snapshot!(
pretty_format_batches(&sql_results).unwrap(),
@r"
+---------------+----------------------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+----------------------------------------------------------------------------------------------------------------------------+
| logical_plan | Projection: t1.a, t1.b |
| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |
| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |
| | Left Join: t1.a = __scalar_sq_1.a |
| | TableScan: t1 projection=[a, b] |
| | SubqueryAlias: __scalar_sq_1 |
| | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true |
| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] |
| | TableScan: t2 projection=[a] |
| physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |
| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |
| | ProjectionExec: expr=[a@2 as a, b@3 as b, count(*)@0 as count(*), __always_true@1 as __always_true] |
| | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[count(*)@0, __always_true@2, a@3, b@4] |
| | CoalescePartitionsExec |
| | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] |
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] |
| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | |
+---------------+----------------------------------------------------------------------------------------------------------------------------+
+---------------+--------------------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+--------------------------------------------------------------------------------------------------------------------------+
| logical_plan | Projection: t1.a, t1.b |
| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |
| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |
| | Left Join: t1.a = __scalar_sq_1.a |
| | TableScan: t1 projection=[a, b] |
| | SubqueryAlias: __scalar_sq_1 |
| | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true |
| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] |
| | TableScan: t2 projection=[a] |
| physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |
| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |
| | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[a@3, b@4, count(*)@0, __always_true@2] |
| | CoalescePartitionsExec |
| | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] |
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] |
| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | |
+---------------+--------------------------------------------------------------------------------------------------------------------------+
"
);

Expand Down Expand Up @@ -3485,31 +3484,30 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
assert_snapshot!(
pretty_format_batches(&df_results).unwrap(),
@r"
+---------------+----------------------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+----------------------------------------------------------------------------------------------------------------------------+
| logical_plan | Projection: t1.a, t1.b |
| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |
| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |
| | Left Join: t1.a = __scalar_sq_1.a |
| | TableScan: t1 projection=[a, b] |
| | SubqueryAlias: __scalar_sq_1 |
| | Projection: count(*), t2.a, Boolean(true) AS __always_true |
| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] |
| | TableScan: t2 projection=[a] |
| physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |
| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |
| | ProjectionExec: expr=[a@2 as a, b@3 as b, count(*)@0 as count(*), __always_true@1 as __always_true] |
| | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[count(*)@0, __always_true@2, a@3, b@4] |
| | CoalescePartitionsExec |
| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |
| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | |
+---------------+----------------------------------------------------------------------------------------------------------------------------+
+---------------+--------------------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+--------------------------------------------------------------------------------------------------------------------------+
| logical_plan | Projection: t1.a, t1.b |
| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |
| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |
| | Left Join: t1.a = __scalar_sq_1.a |
| | TableScan: t1 projection=[a, b] |
| | SubqueryAlias: __scalar_sq_1 |
| | Projection: count(*), t2.a, Boolean(true) AS __always_true |
| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] |
| | TableScan: t2 projection=[a] |
| physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |
| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |
| | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[a@3, b@4, count(*)@0, __always_true@2] |
| | CoalescePartitionsExec |
| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |
| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | |
+---------------+--------------------------------------------------------------------------------------------------------------------------+
"
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1313,8 +1313,8 @@ fn test_hash_join_after_projection() -> Result<()> {
assert_snapshot!(
actual,
@r"
ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right]
HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7]
ProjectionExec: expr=[c@0 as c_from_left, b@1 as b_from_left, a@2 as a_from_left, c@3 as c_from_right]
HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[c@2, b@1, a@0, c@7]
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false
"
Expand Down
61 changes: 42 additions & 19 deletions datafusion/physical-plan/src/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ use datafusion_execution::TaskContext;
use datafusion_expr::ExpressionPlacement;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::projection::Projector;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
use datafusion_physical_expr_common::sort_expr::{
LexOrdering, LexRequirement, PhysicalSortExpr,
Expand Down Expand Up @@ -607,13 +606,7 @@ pub fn try_embed_projection<Exec: EmbeddedProjection + 'static>(
return Ok(None);
};

// If the projection indices is the same as the input columns, we don't need to embed the projection to hash join.
// Check the projection_index is 0..n-1 and the length of projection_index is the same as the length of execution_plan schema fields.
if projection_index.len() == projection_index.last().unwrap() + 1
&& projection_index.len() == execution_plan.schema().fields().len()
{
return Ok(None);
}
let columns_reduced = projection_index.len() < execution_plan.schema().fields().len();

let new_execution_plan =
Arc::new(execution_plan.with_projection(Some(projection_index.to_vec()))?);
Expand Down Expand Up @@ -648,9 +641,16 @@ pub fn try_embed_projection<Exec: EmbeddedProjection + 'static>(
Arc::clone(&new_execution_plan) as _,
)?);
if is_projection_removable(&new_projection) {
// Residual is identity — embedding fully absorbed the projection.
Ok(Some(new_execution_plan))
} else {
} else if columns_reduced {
// Embedding reduced columns even though a residual is still needed
// for renames or expressions — worth keeping.
Ok(Some(new_projection))
} else {
// No columns eliminated and residual still needed — embedding just
// adds an unnecessary column reorder inside the operator.
Ok(None)
}
}

Expand Down Expand Up @@ -1080,15 +1080,37 @@ fn try_unifying_projections(

/// Collect all column indices from the given projection expressions.
fn collect_column_indices(exprs: &[ProjectionExpr]) -> Vec<usize> {
// Collect indices and remove duplicates.
let mut indices = exprs
.iter()
.flat_map(|proj_expr| collect_columns(&proj_expr.expr))
.map(|x| x.index())
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect::<Vec<_>>();
indices.sort();
// Collect column indices in a deterministic order that preserves the
// projection's column ordering. For simple Column expressions, we use
// the column index directly. For complex expressions, we walk the
// expression tree to collect column references in traversal order.
// This allows the embedded projection to match the desired output
// column order, avoiding a residual ProjectionExec.
let mut seen = std::collections::HashSet::new();
let mut indices = Vec::new();
for proj_expr in exprs {
if let Some(col) = proj_expr.expr.as_any().downcast_ref::<Column>() {
// Simple column reference: preserve projection order.
if seen.insert(col.index()) {
indices.push(col.index());
}
} else {
// Complex expression: collect all referenced columns in
// expression tree traversal order (deterministic) to preserve
// the natural ordering of column references.
proj_expr
.expr
.apply(|expr| {
if let Some(col) = expr.as_any().downcast_ref::<Column>()
&& seen.insert(col.index())
{
indices.push(col.index());
}
Ok(TreeNodeRecursion::Continue)
})
.expect("closure always returns OK");
}
}
indices
}

Expand Down Expand Up @@ -1202,7 +1224,8 @@ mod tests {
expr,
alias: "b-(1+a)".to_string(),
}]);
assert_eq!(column_indices, vec![1, 7]);
// Tree traversal order: b@7 is visited before a@1
assert_eq!(column_indices, vec![7, 1]);
Ok(())
}

Expand Down
Loading
Loading